{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "CORDS_SL_CIFAR10_Custom_Train.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "d17ddaa65d464c9ca0823c9af032f7d4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_7a304fee4b7544e09834f0b88cc8ed0d",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_c563aa40e8c3470cbb249f9a05f6104a",
              "IPY_MODEL_fc43cbc57afb465182002ae7f1a56e64",
              "IPY_MODEL_09b777c37b954c8390ca3a87b148e2f0"
            ]
          }
        },
        "7a304fee4b7544e09834f0b88cc8ed0d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "c563aa40e8c3470cbb249f9a05f6104a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_a85d8d0e18e444d7b2f2a81a42483825",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": "",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_2b7cfdb333634389922ceb12df38f8d8"
          }
        },
        "fc43cbc57afb465182002ae7f1a56e64": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_7b8fbd10065e4d9880f53525e686bb4b",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 170498071,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 170498071,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_c89610fbe1564b06b4516226638613ad"
          }
        },
        "09b777c37b954c8390ca3a87b148e2f0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_2d9dac0aaa93491b9cae7922fc8e8981",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 170499072/? [00:11&lt;00:00, 16432977.99it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_8840f48e59254aa3b14a6416d299bf45"
          }
        },
        "a85d8d0e18e444d7b2f2a81a42483825": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "2b7cfdb333634389922ceb12df38f8d8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "7b8fbd10065e4d9880f53525e686bb4b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "c89610fbe1564b06b4516226638613ad": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "2d9dac0aaa93491b9cae7922fc8e8981": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "8840f48e59254aa3b14a6416d299bf45": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nRBrJb8I_vUv"
      },
      "source": [
        "# Cloning CORDS repository"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "x35Mfc-RnKkX",
        "outputId": "d8a4cdb6-9def-4155-9438-3f4da853a39e"
      },
      "source": [
        "# Clone the CORDS repository and make its root the working directory; the repo root holds the\n",
        "# `cords/` package that the `from cords...` imports further down rely on.\n",
        "# NOTE(review): not safe to re-run without restarting the runtime -- after `%cd cords/` the cwd already\n",
        "# contains a `cords/` directory, so a second clone fails and `%cd cords/` descends into the package dir.\n",
        "!git clone https://github.com/decile-team/cords.git\n",
        "%cd cords/\n",
        "%ls"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cloning into 'cords'...\n",
            "remote: Enumerating objects: 3078, done.\u001b[K\n",
            "remote: Counting objects: 100% (1700/1700), done.\u001b[K\n",
            "remote: Compressing objects: 100% (782/782), done.\u001b[K\n",
            "remote: Total 3078 (delta 1087), reused 1466 (delta 898), pack-reused 1378\u001b[K\n",
            "Receiving objects: 100% (3078/3078), 52.84 MiB | 17.36 MiB/s, done.\n",
            "Resolving deltas: 100% (1824/1824), done.\n",
            "/content/cords\n",
            "\u001b[0m\u001b[01;34mbenchmarks\u001b[0m/  \u001b[01;34mexamples\u001b[0m/       \u001b[01;34mrequirements\u001b[0m/        run_ssl.py   train_ssl.py\n",
            "\u001b[01;34mconfigs\u001b[0m/     LICENSE.txt     requirements.txt     setup.py\n",
            "\u001b[01;34mcords\u001b[0m/       paramtuning.py  run_param_tuning.py  \u001b[01;34mtests\u001b[0m/\n",
            "\u001b[01;34mdocs\u001b[0m/        README.md       run_sl.py            train_sl.py\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gAA3K0cVnyd9"
      },
      "source": [
        "# Install prerequisite libraries of CORDS"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6CXZ4L1ynmcp",
        "outputId": "cc3e0d52-0609-4248-dac3-60d1caeb71f8"
      },
      "source": [
        "# Install CORDS' prerequisites in a single resolver pass.\n",
        "# `%pip` (unlike `!pip`) always installs into the environment of the running kernel, and the\n",
        "# extras are quoted so the shell cannot treat the square brackets as a glob pattern.\n",
        "%pip install dotmap apricot-select \"ray[default]\" \"ray[tune]\""
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting dotmap\n",
            "  Downloading dotmap-1.3.24-py3-none-any.whl (11 kB)\n",
            "Installing collected packages: dotmap\n",
            "Successfully installed dotmap-1.3.24\n",
            "Collecting apricot-select\n",
            "  Downloading apricot-select-0.6.1.tar.gz (28 kB)\n",
            "Requirement already satisfied: numpy>=1.14.2 in /usr/local/lib/python3.7/dist-packages (from apricot-select) (1.19.5)\n",
            "Requirement already satisfied: scipy>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from apricot-select) (1.4.1)\n",
            "Requirement already satisfied: numba>=0.43.0 in /usr/local/lib/python3.7/dist-packages (from apricot-select) (0.51.2)\n",
            "Requirement already satisfied: tqdm>=4.24.0 in /usr/local/lib/python3.7/dist-packages (from apricot-select) (4.62.3)\n",
            "Collecting nose\n",
            "  Downloading nose-1.3.7-py3-none-any.whl (154 kB)\n",
            "\u001b[K     |████████████████████████████████| 154 kB 5.5 MB/s \n",
            "\u001b[?25hRequirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from numba>=0.43.0->apricot-select) (57.4.0)\n",
            "Requirement already satisfied: llvmlite<0.35,>=0.34.0.dev0 in /usr/local/lib/python3.7/dist-packages (from numba>=0.43.0->apricot-select) (0.34.0)\n",
            "Building wheels for collected packages: apricot-select\n",
            "  Building wheel for apricot-select (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for apricot-select: filename=apricot_select-0.6.1-py3-none-any.whl size=48787 sha256=e8933f95e093effcc31c07758debf5144a5193270b434344798086a5c4000a8d\n",
            "  Stored in directory: /root/.cache/pip/wheels/1d/b0/5d/41bab30f23d17864700963dad70bbeda159a409e94f0778f2f\n",
            "Successfully built apricot-select\n",
            "Installing collected packages: nose, apricot-select\n",
            "Successfully installed apricot-select-0.6.1 nose-1.3.7\n",
            "Collecting ray[default]\n",
            "  Downloading ray-1.7.0-cp37-cp37m-manylinux2014_x86_64.whl (54.0 MB)\n",
            "\u001b[K     |████████████████████████████████| 54.0 MB 89 kB/s \n",
            "\u001b[?25hRequirement already satisfied: protobuf>=3.15.3 in /usr/local/lib/python3.7/dist-packages (from ray[default]) (3.17.3)\n",
            "Requirement already satisfied: attrs in /usr/local/lib/python3.7/dist-packages (from ray[default]) (21.2.0)\n",
            "Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.7/dist-packages (from ray[default]) (7.1.2)\n",
            "Collecting redis>=3.5.0\n",
            "  Downloading redis-3.5.3-py2.py3-none-any.whl (72 kB)\n",
            "\u001b[K     |████████████████████████████████| 72 kB 628 kB/s \n",
            "\u001b[?25hRequirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from ray[default]) (3.13)\n",
            "Requirement already satisfied: grpcio>=1.28.1 in /usr/local/lib/python3.7/dist-packages (from ray[default]) (1.41.0)\n",
            "Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from ray[default]) (1.0.2)\n",
            "Requirement already satisfied: numpy>=1.16 in /usr/local/lib/python3.7/dist-packages (from ray[default]) (1.19.5)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from ray[default]) (3.3.0)\n",
            "Requirement already satisfied: jsonschema in /usr/local/lib/python3.7/dist-packages (from ray[default]) (2.6.0)\n",
            "Collecting opencensus\n",
            "  Downloading opencensus-0.8.0-py2.py3-none-any.whl (128 kB)\n",
            "\u001b[K     |████████████████████████████████| 128 kB 58.4 MB/s \n",
            "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from ray[default]) (2.23.0)\n",
            "Collecting colorful\n",
            "  Downloading colorful-0.5.4-py2.py3-none-any.whl (201 kB)\n",
            "\u001b[K     |████████████████████████████████| 201 kB 56.7 MB/s \n",
            "\u001b[?25hRequirement already satisfied: prometheus-client>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from ray[default]) (0.11.0)\n",
            "Collecting py-spy>=0.2.0\n",
            "  Downloading py_spy-0.3.10-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (3.2 MB)\n",
            "\u001b[K     |████████████████████████████████| 3.2 MB 55.2 MB/s \n",
            "\u001b[?25hCollecting aioredis<2\n",
            "  Downloading aioredis-1.3.1-py3-none-any.whl (65 kB)\n",
            "\u001b[K     |████████████████████████████████| 65 kB 3.3 MB/s \n",
            "\u001b[?25hCollecting aiohttp\n",
            "  Downloading aiohttp-3.7.4.post0-cp37-cp37m-manylinux2014_x86_64.whl (1.3 MB)\n",
            "\u001b[K     |████████████████████████████████| 1.3 MB 40.4 MB/s \n",
            "\u001b[?25hCollecting gpustat\n",
            "  Downloading gpustat-0.6.0.tar.gz (78 kB)\n",
            "\u001b[K     |████████████████████████████████| 78 kB 6.3 MB/s \n",
            "\u001b[?25hCollecting aiohttp-cors\n",
            "  Downloading aiohttp_cors-0.7.0-py3-none-any.whl (27 kB)\n",
            "Collecting hiredis\n",
            "  Downloading hiredis-2.0.0-cp37-cp37m-manylinux2010_x86_64.whl (85 kB)\n",
            "\u001b[K     |████████████████████████████████| 85 kB 4.6 MB/s \n",
            "\u001b[?25hCollecting async-timeout\n",
            "  Downloading async_timeout-3.0.1-py3-none-any.whl (8.2 kB)\n",
            "Requirement already satisfied: six>=1.5.2 in /usr/local/lib/python3.7/dist-packages (from grpcio>=1.28.1->ray[default]) (1.15.0)\n",
            "Collecting multidict<7.0,>=4.5\n",
            "  Downloading multidict-5.2.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (160 kB)\n",
            "\u001b[K     |████████████████████████████████| 160 kB 61.3 MB/s \n",
            "\u001b[?25hCollecting yarl<2.0,>=1.0\n",
            "  Downloading yarl-1.7.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (271 kB)\n",
            "\u001b[K     |████████████████████████████████| 271 kB 58.3 MB/s \n",
            "\u001b[?25hRequirement already satisfied: typing-extensions>=3.6.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp->ray[default]) (3.7.4.3)\n",
            "Requirement already satisfied: chardet<5.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->ray[default]) (3.0.4)\n",
            "Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.7/dist-packages (from yarl<2.0,>=1.0->aiohttp->ray[default]) (2.10)\n",
            "Requirement already satisfied: nvidia-ml-py3>=7.352.0 in /usr/local/lib/python3.7/dist-packages (from gpustat->ray[default]) (7.352.0)\n",
            "Requirement already satisfied: psutil in /usr/local/lib/python3.7/dist-packages (from gpustat->ray[default]) (5.4.8)\n",
            "Collecting blessings>=1.6\n",
            "  Downloading blessings-1.7-py3-none-any.whl (18 kB)\n",
            "Requirement already satisfied: google-api-core<3.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from opencensus->ray[default]) (1.26.3)\n",
            "Collecting opencensus-context==0.1.2\n",
            "  Downloading opencensus_context-0.1.2-py2.py3-none-any.whl (4.4 kB)\n",
            "Requirement already satisfied: pytz in /usr/local/lib/python3.7/dist-packages (from google-api-core<3.0.0,>=1.0.0->opencensus->ray[default]) (2018.9)\n",
            "Requirement already satisfied: setuptools>=40.3.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<3.0.0,>=1.0.0->opencensus->ray[default]) (57.4.0)\n",
            "Requirement already satisfied: google-auth<2.0dev,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from google-api-core<3.0.0,>=1.0.0->opencensus->ray[default]) (1.35.0)\n",
            "Requirement already satisfied: googleapis-common-protos<2.0dev,>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<3.0.0,>=1.0.0->opencensus->ray[default]) (1.53.0)\n",
            "Requirement already satisfied: packaging>=14.3 in /usr/local/lib/python3.7/dist-packages (from google-api-core<3.0.0,>=1.0.0->opencensus->ray[default]) (21.0)\n",
            "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<2.0dev,>=1.21.1->google-api-core<3.0.0,>=1.0.0->opencensus->ray[default]) (0.2.8)\n",
            "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<2.0dev,>=1.21.1->google-api-core<3.0.0,>=1.0.0->opencensus->ray[default]) (4.7.2)\n",
            "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<2.0dev,>=1.21.1->google-api-core<3.0.0,>=1.0.0->opencensus->ray[default]) (4.2.4)\n",
            "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=14.3->google-api-core<3.0.0,>=1.0.0->opencensus->ray[default]) (2.4.7)\n",
            "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<2.0dev,>=1.21.1->google-api-core<3.0.0,>=1.0.0->opencensus->ray[default]) (0.4.8)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->ray[default]) (2021.5.30)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->ray[default]) (1.24.3)\n",
            "Building wheels for collected packages: gpustat\n",
            "  Building wheel for gpustat (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for gpustat: filename=gpustat-0.6.0-py3-none-any.whl size=12617 sha256=43e65a4da72cecf5e3d522ba9a80bb7029cae5d37ae58af997e1a72efe60e11a\n",
            "  Stored in directory: /root/.cache/pip/wheels/e6/67/af/f1ad15974b8fd95f59a63dbf854483ebe5c7a46a93930798b8\n",
            "Successfully built gpustat\n",
            "Installing collected packages: multidict, yarl, async-timeout, redis, opencensus-context, hiredis, blessings, aiohttp, ray, py-spy, opencensus, gpustat, colorful, aioredis, aiohttp-cors\n",
            "Successfully installed aiohttp-3.7.4.post0 aiohttp-cors-0.7.0 aioredis-1.3.1 async-timeout-3.0.1 blessings-1.7 colorful-0.5.4 gpustat-0.6.0 hiredis-2.0.0 multidict-5.2.0 opencensus-0.8.0 opencensus-context-0.1.2 py-spy-0.3.10 ray-1.7.0 redis-3.5.3 yarl-1.7.0\n",
            "Requirement already satisfied: ray[tune] in /usr/local/lib/python3.7/dist-packages (1.7.0)\n",
            "Requirement already satisfied: attrs in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (21.2.0)\n",
            "Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (1.0.2)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (3.3.0)\n",
            "Requirement already satisfied: protobuf>=3.15.3 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (3.17.3)\n",
            "Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (3.13)\n",
            "Requirement already satisfied: grpcio>=1.28.1 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (1.41.0)\n",
            "Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (7.1.2)\n",
            "Requirement already satisfied: numpy>=1.16 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (1.19.5)\n",
            "Requirement already satisfied: redis>=3.5.0 in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (3.5.3)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (2.23.0)\n",
            "Collecting tensorboardX>=1.9\n",
            "  Downloading tensorboardX-2.4-py2.py3-none-any.whl (124 kB)\n",
            "\u001b[K     |████████████████████████████████| 124 kB 3.9 MB/s \n",
            "\u001b[?25hRequirement already satisfied: tabulate in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (0.8.9)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from ray[tune]) (1.1.5)\n",
            "Requirement already satisfied: six>=1.5.2 in /usr/local/lib/python3.7/dist-packages (from grpcio>=1.28.1->ray[tune]) (1.15.0)\n",
            "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->ray[tune]) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->ray[tune]) (2018.9)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->ray[tune]) (3.0.4)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->ray[tune]) (2021.5.30)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->ray[tune]) (2.10)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->ray[tune]) (1.24.3)\n",
            "Installing collected packages: tensorboardX\n",
            "Successfully installed tensorboardX-2.4\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gzyCsbnJn3_L"
      },
      "source": [
        "# Import necessary libraries"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GazQv21pjIKd"
      },
      "source": [
        "import time\n",
        "import numpy as np\n",
        "import os\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from cords.utils.data.datasets.SL import gen_dataset\n",
        "from torch.utils.data import Subset\n",
        "from cords.utils.config_utils import load_config_data\n",
        "import os.path as osp\n",
        "from cords.utils.data.data_utils import WeightedSubset\n",
        "from ray import tune"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MyMJdoeqok49"
      },
      "source": [
        "# Loading the CIFAR10 dataset\n",
        "\n",
        "Since CIFAR10 is a predefined dataset in the CORDS repository, you can load it with the `gen_dataset` function.\n",
        "\n",
        "**Input parameters of gen_dataset function:**\n",
        "\n",
        "***datadir :*** Directory containing the data. If the data is not already present, it is downloaded automatically into this directory.\n",
        "\n",
        "***dset_name :*** Name of the dataset to load (e.g. `cifar10`).\n",
        "\n",
        "***feature :*** If \"classimb\", the dataset is made class-imbalanced.\n",
        "          If \"noise\", the dataset labels are made noisy.\n",
        "          If None, we return the standard datasets.\n",
        "\n",
        "***isnumpy :*** If True, return dataset in the numpy array format.\n",
        "          If False, return dataset in torch dataset format.\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 100,
          "referenced_widgets": [
            "d17ddaa65d464c9ca0823c9af032f7d4",
            "7a304fee4b7544e09834f0b88cc8ed0d",
            "c563aa40e8c3470cbb249f9a05f6104a",
            "fc43cbc57afb465182002ae7f1a56e64",
            "09b777c37b954c8390ca3a87b148e2f0",
            "a85d8d0e18e444d7b2f2a81a42483825",
            "2b7cfdb333634389922ceb12df38f8d8",
            "7b8fbd10065e4d9880f53525e686bb4b",
            "c89610fbe1564b06b4516226638613ad",
            "2d9dac0aaa93491b9cae7922fc8e8981",
            "8840f48e59254aa3b14a6416d299bf45"
          ]
        },
        "id": "rkjRkzs2olSD",
        "outputId": "80f93477-b325-4e64-dbeb-b1f9c3b6b1d8"
      },
      "source": [
        "# Load CIFAR10 into 'data/' (downloaded on first use); feature=None selects the standard dataset.\n",
        "trainset, validset, testset, num_cls = gen_dataset('data/', 'cifar10', feature=None, isnumpy=False)\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d17ddaa65d464c9ca0823c9af032f7d4",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "  0%|          | 0/170498071 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Extracting data/cifar-10-python.tar.gz to data/\n",
            "Files already downloaded and verified\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lnL-sve1qrnP"
      },
      "source": [
        "# Create dataloaders"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QnvXqaGbqnhH"
      },
      "source": [
        "trn_batch_size = 20\n",
        "val_batch_size = 20\n",
        "tst_batch_size = 1000\n",
        "\n",
        "# Creating the Data Loaders\n",
        "trainloader = torch.utils.data.DataLoader(trainset, batch_size=trn_batch_size,\n",
        "                                          shuffle=False, pin_memory=True)\n",
        "\n",
        "valloader = torch.utils.data.DataLoader(validset, batch_size=val_batch_size,\n",
        "                                        shuffle=False, pin_memory=True)\n",
        "\n",
        "testloader = torch.utils.data.DataLoader(testset, batch_size=tst_batch_size,\n",
        "                                          shuffle=False, pin_memory=True)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WGI8vbIF1IOS"
      },
      "source": [
        "#Defining Model\n",
        "\n",
        "CORDS has a set of predefined models bulit in utils folder. You can import them directly."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "f97m03ZbqvNK"
      },
      "source": [
        "from cords.utils.models import ResNet18\n",
        "numclasses = 10\n",
        "device = 'cuda' #Device Argument\n",
        "model = ResNet18(10)\n",
        "model = model.to(device)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H0VxOIx31O4X"
      },
      "source": [
        "# Defining Loss Functions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "h2oAaNkn1OhK"
      },
      "source": [
        "criterion = nn.CrossEntropyLoss()\n",
        "criterion_nored = nn.CrossEntropyLoss(reduction='none')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vM80Iyj76mIM"
      },
      "source": [
        "# Checkpointing Utility functions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vg12aEnb6hmf"
      },
      "source": [
        "def save_ckpt(state, ckpt_path):\n",
        "    torch.save(state, ckpt_path)\n",
        "\n",
        "\n",
        "def load_ckpt(ckpt_path, model, optimizer):\n",
        "    checkpoint = torch.load(ckpt_path)\n",
        "    start_epoch = checkpoint['epoch']\n",
        "    model.load_state_dict(checkpoint['state_dict'])\n",
        "    optimizer.load_state_dict(checkpoint['optimizer'])\n",
        "    loss = checkpoint['loss']\n",
        "    metrics = checkpoint['metrics']\n",
        "    return start_epoch, model, optimizer, loss, metrics\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AmcpTO2tBNPb"
      },
      "source": [
        "# Cumulative time calculation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xapglQXFBM5h"
      },
      "source": [
        "def generate_cumulative_timing(mod_timing):\n",
        "    tmp = 0\n",
        "    mod_cum_timing = np.zeros(len(mod_timing))\n",
        "    for i in range(len(mod_timing)):\n",
        "        tmp += mod_timing[i]\n",
        "        mod_cum_timing[i] = tmp\n",
        "    return mod_cum_timing / 3600\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U9DeC_LZ2MfB"
      },
      "source": [
        "# Defining Optimizers and schedulers"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "14_R6HQT2L0I"
      },
      "source": [
        "optimizer = optim.SGD(model.parameters(), lr=1e-2,\n",
        "                                  momentum=0.9,\n",
        "                                  weight_decay=5e-4,\n",
        "                                  nesterov=False)\n",
        "\n",
        "#T_max is the maximum number of scheduler steps. Here we are using the number of epochs as the maximum number of scheduler steps.\n",
        "\n",
        "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,\n",
        "                                                       T_max=300) \n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uYOP5EWU_UD7"
      },
      "source": [
        "#Get logger object for logging"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LZ8H3_Hx_TtF"
      },
      "source": [
        "def __get_logger(results_dir):\n",
        "  os.makedirs(results_dir, exist_ok=True)\n",
        "  # setup logger\n",
        "  plain_formatter = logging.Formatter(\"[%(asctime)s] %(name)s %(levelname)s: %(message)s\",\n",
        "                                      datefmt=\"%m/%d %H:%M:%S\")\n",
        "  logger = logging.getLogger(__name__)\n",
        "  logger.setLevel(logging.INFO)\n",
        "  s_handler = logging.StreamHandler(stream=sys.stdout)\n",
        "  s_handler.setFormatter(plain_formatter)\n",
        "  s_handler.setLevel(logging.INFO)\n",
        "  logger.addHandler(s_handler)\n",
        "  f_handler = logging.FileHandler(os.path.join(results_dir, \"results.log\"))\n",
        "  f_handler.setFormatter(plain_formatter)\n",
        "  f_handler.setLevel(logging.DEBUG)\n",
        "  logger.addHandler(f_handler)\n",
        "  logger.propagate = False\n",
        "  return logger\n",
        "\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MgQ2I6bDyMwN"
      },
      "source": [
        "#Instantiating logger file for logging the information"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "k_IvqZDtyMg9"
      },
      "source": [
        "import logging\n",
        "import os\n",
        "import os.path as osp\n",
        "import sys\n",
        "\n",
        "#Results logging directory\n",
        "results_dir = osp.abspath(osp.expanduser('results'))\n",
        "logger = __get_logger(results_dir)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NuYxvrYm-v1Z",
        "outputId": "5de8e278-16ca-42b0-9ceb-7ca3508a891f"
      },
      "source": [
        "logger.info(\"hello\")"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[10/19 19:28:17] __main__ INFO: hello\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vq_ehn_0vPjZ"
      },
      "source": [
        "# Instantiating GLISTER subset selection dataloaders\n",
        "We instantiate subset dataloaders that can be used for training the models with adaptive subsets.\n",
        "\n",
        "Each subset dataloader needs data selection strategy arguments in the form of a dotmap dictionary, logger and dataloader specific arguments like batch size, shuffle etc.\n",
        "\n",
        "We are instantiating GLISTER dataloader here with no warm start. But any dataloader can be instantiated in the same way by passing the required arguments\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8TNMpF36xykF",
        "outputId": "e5a084e0-dfd9-4228-81e5-3f8fd96fbf94"
      },
      "source": [
        "from cords.utils.data.dataloader.SL.adaptive import GLISTERDataLoader, OLRandomDataLoader, \\\n",
        "    CRAIGDataLoader, GradMatchDataLoader, RandomDataLoader\n",
        "from dotmap import DotMap\n",
        "\n",
        "selection_strategy = 'GLISTER'\n",
        "dss_args = dict(model=model,\n",
        "                loss=criterion_nored,\n",
        "                eta=0.01,\n",
        "                num_classes=10,\n",
        "                num_epochs=300,\n",
        "                device='cuda',\n",
        "                fraction=0.1,\n",
        "                select_every=20,\n",
        "                kappa=0,\n",
        "                linear_layer=False,\n",
        "                selection_type='SL',\n",
        "                greedy='Stochastic')\n",
        "dss_args = DotMap(dss_args)\n",
        "\n",
        "dataloader = GLISTERDataLoader(trainloader, valloader, dss_args, logger, \n",
        "                                  batch_size=20, \n",
        "                                  shuffle=True,\n",
        "                                  pin_memory=False)\n",
        "\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/numba/np/ufunc/parallel.py:363: NumbaWarning: \u001b[1mThe TBB threading layer requires TBB version 2019.5 or later i.e., TBB_INTERFACE_VERSION >= 11005. Found TBB_INTERFACE_VERSION = 9107. The TBB threading layer is disabled.\u001b[0m\n",
            "  warnings.warn(problem)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MFmFiQ7u3czc"
      },
      "source": [
        "# Additional arguments for training, evaluation and checkpointing"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gCjZtiAs3cel"
      },
      "source": [
        "#Training Arguments\n",
        "num_epochs = 300\n",
        "\n",
        "#Arguments for results logging\n",
        "print_every = 1\n",
        "print_args = [\"val_loss\", \"val_acc\", \"tst_loss\", \"tst_acc\", \"time\"]\n",
        "\n",
        "#Argumets for checkpointing\n",
        "save_every = 20\n",
        "is_save = True\n",
        "\n",
        "#Evaluation Metrics\n",
        "trn_losses = list()\n",
        "val_losses = list()\n",
        "tst_losses = list()\n",
        "subtrn_losses = list()\n",
        "timing = list()\n",
        "trn_acc = list()\n",
        "val_acc = list()  \n",
        "tst_acc = list()  \n",
        "subtrn_acc = list()\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tGF82RIQzeLK"
      },
      "source": [
        "#Custom Training loop with evaluation\n",
        "\n",
        "Subset dataloader returns data samples, labels and associated weights with each data sample. Hence, inorder to incorporate the weights in the dataloader into the training loop, we use a **loss function**  with **reduction='none'** to get per-sample loss values. Then we calculate the weighted average of batch losses using the following code snippet:\n",
        "\n",
        "`loss = torch.dot(losses, weights/(weights.sum()))`\n",
        "\n",
        "---\n",
        "**NOTE**\n",
        "\n",
        "### If you want to implement a custom training loop, please note that the subset dataloaders also returns additional weight parameter for each data sample.\n",
        "---"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BAl87AX7zwUX",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "fc7d43e3-cfa2-4002-98be-91c744d6e297"
      },
      "source": [
        "\"\"\"\n",
        "################################################# Training Loop #################################################\n",
        "\"\"\"\n",
        "for epoch in range(num_epochs):\n",
        "    subtrn_loss = 0\n",
        "    subtrn_correct = 0\n",
        "    subtrn_total = 0\n",
        "    model.train()\n",
        "    start_time = time.time()\n",
        "    for _, (inputs, targets, weights) in enumerate(dataloader):\n",
        "        inputs = inputs.to(device)\n",
        "        targets = targets.to(device, non_blocking=True)\n",
        "        weights = weights.to(device)  \n",
        "        optimizer.zero_grad()\n",
        "        outputs = model(inputs)\n",
        "        losses = criterion_nored(outputs, targets)\n",
        "        loss = torch.dot(losses, weights/(weights.sum()))\n",
        "        loss.backward()\n",
        "        subtrn_loss += loss.item()\n",
        "        optimizer.step()\n",
        "        _, predicted = outputs.max(1)\n",
        "        subtrn_total += targets.size(0)\n",
        "        subtrn_correct += predicted.eq(targets).sum().item()\n",
        "    epoch_time = time.time() - start_time\n",
        "    scheduler.step()\n",
        "    timing.append(epoch_time)\n",
        "\n",
        "\n",
        "    \"\"\"\n",
        "    ################################################# Evaluation Loop #################################################\n",
        "    \"\"\"\n",
        "\n",
        "    if (epoch + 1) % print_every == 0:\n",
        "        trn_loss = 0\n",
        "        trn_correct = 0\n",
        "        trn_total = 0\n",
        "        val_loss = 0\n",
        "        val_correct = 0\n",
        "        val_total = 0\n",
        "        tst_correct = 0\n",
        "        tst_total = 0\n",
        "        tst_loss = 0\n",
        "        model.eval()\n",
        "\n",
        "        if (\"trn_loss\" in print_args) or (\"trn_acc\" in print_args):\n",
        "            with torch.no_grad():\n",
        "                for _, (inputs, targets) in enumerate(trainloader):\n",
        "                    inputs, targets = inputs.to(device), \\\n",
        "                                      targets.to(device, non_blocking=True)\n",
        "                    outputs = model(inputs)\n",
        "                    loss = criterion(outputs, targets)\n",
        "                    trn_loss += loss.item()\n",
        "                    if \"trn_acc\" in print_args:\n",
        "                        _, predicted = outputs.max(1)\n",
        "                        trn_total += targets.size(0)\n",
        "                        trn_correct += predicted.eq(targets).sum().item()\n",
        "                trn_losses.append(trn_loss)\n",
        "\n",
        "            if \"trn_acc\" in print_args:\n",
        "                trn_acc.append(trn_correct / trn_total)\n",
        "\n",
        "        if (\"val_loss\" in print_args) or (\"val_acc\" in print_args):\n",
        "            with torch.no_grad():\n",
        "                for _, (inputs, targets) in enumerate(valloader):\n",
        "                    inputs, targets = inputs.to(device), \\\n",
        "                                      targets.to(device, non_blocking=True)\n",
        "                    outputs = model(inputs)\n",
        "                    loss = criterion(outputs, targets)\n",
        "                    val_loss += loss.item()\n",
        "                    if \"val_acc\" in print_args:\n",
        "                        _, predicted = outputs.max(1)\n",
        "                        val_total += targets.size(0)\n",
        "                        val_correct += predicted.eq(targets).sum().item()\n",
        "                val_losses.append(val_loss)\n",
        "\n",
        "            if \"val_acc\" in print_args:\n",
        "                val_acc.append(val_correct / val_total)\n",
        "\n",
        "        if (\"tst_loss\" in print_args) or (\"tst_acc\" in print_args):\n",
        "            with torch.no_grad():\n",
        "                for _, (inputs, targets) in enumerate(testloader):\n",
        "                    inputs, targets = inputs.to(device), \\\n",
        "                                      targets.to(device, non_blocking=True)\n",
        "                    outputs = model(inputs)\n",
        "                    loss = criterion(outputs, targets)\n",
        "                    tst_loss += loss.item()\n",
        "                    if \"tst_acc\" in print_args:\n",
        "                        _, predicted = outputs.max(1)\n",
        "                        tst_total += targets.size(0)\n",
        "                        tst_correct += predicted.eq(targets).sum().item()\n",
        "                tst_losses.append(tst_loss)\n",
        "\n",
        "            if \"tst_acc\" in print_args:\n",
        "                tst_acc.append(tst_correct / tst_total)\n",
        "\n",
        "        if \"subtrn_acc\" in print_args:\n",
        "            subtrn_acc.append(subtrn_correct / subtrn_total)\n",
        "\n",
        "        if \"subtrn_losses\" in print_args:\n",
        "            subtrn_losses.append(subtrn_loss)\n",
        "\n",
        "        print_str = \"Epoch: \" + str(epoch + 1)\n",
        "\n",
        "        \"\"\"\n",
        "        ################################################# Results Printing #################################################\n",
        "        \"\"\"\n",
        "\n",
        "        for arg in print_args:\n",
        "\n",
        "            if arg == \"val_loss\":\n",
        "                print_str += \" , \" + \"Validation Loss: \" + str(val_losses[-1])\n",
        "\n",
        "            if arg == \"val_acc\":\n",
        "                print_str += \" , \" + \"Validation Accuracy: \" + str(val_acc[-1])\n",
        "\n",
        "            if arg == \"tst_loss\":\n",
        "                print_str += \" , \" + \"Test Loss: \" + str(tst_losses[-1])\n",
        "\n",
        "            if arg == \"tst_acc\":\n",
        "                print_str += \" , \" + \"Test Accuracy: \" + str(tst_acc[-1])\n",
        "\n",
        "            if arg == \"trn_loss\":\n",
        "                print_str += \" , \" + \"Training Loss: \" + str(trn_losses[-1])\n",
        "\n",
        "            if arg == \"trn_acc\":\n",
        "                print_str += \" , \" + \"Training Accuracy: \" + str(trn_acc[-1])\n",
        "\n",
        "            if arg == \"subtrn_loss\":\n",
        "                print_str += \" , \" + \"Subset Loss: \" + str(subtrn_losses[-1])\n",
        "\n",
        "            if arg == \"subtrn_acc\":\n",
        "                print_str += \" , \" + \"Subset Accuracy: \" + str(subtrn_acc[-1])\n",
        "\n",
        "            if arg == \"time\":\n",
        "                print_str += \" , \" + \"Timing: \" + str(timing[-1])\n",
        "\n",
        "        logger.info(print_str)\n",
        "\n",
        "    \"\"\"\n",
        "    ################################################# Checkpoint Saving #################################################\n",
        "    \"\"\"\n",
        "\n",
        "    if ((epoch + 1) % save_every == 0) and is_save:\n",
        "\n",
        "        metric_dict = {}\n",
        "\n",
        "        for arg in print_args:\n",
        "            if arg == \"val_loss\":\n",
        "                metric_dict['val_loss'] = val_losses\n",
        "            if arg == \"val_acc\":\n",
        "                metric_dict['val_acc'] = val_acc\n",
        "            if arg == \"tst_loss\":\n",
        "                metric_dict['tst_loss'] = tst_losses\n",
        "            if arg == \"tst_acc\":\n",
        "                metric_dict['tst_acc'] = tst_acc\n",
        "            if arg == \"trn_loss\":\n",
        "                metric_dict['trn_loss'] = trn_losses\n",
        "            if arg == \"trn_acc\":\n",
        "                metric_dict['trn_acc'] = trn_acc\n",
        "            if arg == \"subtrn_loss\":\n",
        "                metric_dict['subtrn_loss'] = subtrn_losses\n",
        "            if arg == \"subtrn_acc\":\n",
        "                metric_dict['subtrn_acc'] = subtrn_acc\n",
        "            if arg == \"time\":\n",
        "                metric_dict['time'] = timing\n",
        "\n",
        "        ckpt_state = {\n",
        "            'epoch': epoch + 1,\n",
        "            'state_dict': model.state_dict(),\n",
        "            'optimizer': optimizer.state_dict(),\n",
        "            'loss': criterion_nored,\n",
        "            'metrics': metric_dict\n",
        "        }\n",
        "\n",
        "        # save checkpoint\n",
        "        save_ckpt(ckpt_state, 'model.pt')\n",
        "        logger.info(\"Model checkpoint saved at epoch: {0:d}\".format(epoch + 1))\n",
        "\n",
        "\"\"\"\n",
        "################################################# Results Summary #################################################\n",
        "\"\"\"\n",
        "\n",
        "logger.info(\"{0:s} Selection Run---------------------------------\".format(selection_strategy))\n",
        "logger.info(\"Final SubsetTrn: {0:f}\".format(subtrn_loss))\n",
        "if \"val_loss\" in print_args:\n",
        "    if \"val_acc\" in print_args:\n",
        "        logger.info(\"Validation Loss: %.2f , Validation Accuracy: %.2f\", val_loss, val_acc[-1])\n",
        "    else:\n",
        "        logger.info(\"Validation Loss: %.2f\", val_loss)\n",
        "\n",
        "if \"tst_loss\" in print_args:\n",
        "    if \"tst_acc\" in print_args:\n",
        "        logger.info(\"Test Loss: %.2f, Test Accuracy: %.2f\", tst_loss, tst_acc[-1])\n",
        "    else:\n",
        "        logger.info(\"Test Data Loss: %f\", tst_loss)\n",
        "logger.info('---------------------------------------------------------------------')\n",
        "logger.info(selection_strategy)\n",
        "logger.info('---------------------------------------------------------------------')\n",
        "\n",
        "\"\"\"\n",
        "################################################# Final Results Logging #################################################\n",
        "\"\"\"\n",
        "\n",
        "if \"val_acc\" in print_args:\n",
        "    val_str = \"Validation Accuracy, \"\n",
        "    for val in val_acc:\n",
        "        val_str = val_str + \" , \" + str(val)\n",
        "    logger.info(val_str)\n",
        "\n",
        "if \"tst_acc\" in print_args:\n",
        "    tst_str = \"Test Accuracy, \"\n",
        "    for tst in tst_acc:\n",
        "        tst_str = tst_str + \" , \" + str(tst)\n",
        "    logger.info(tst_str)\n",
        "\n",
        "if \"time\" in print_args:\n",
        "    time_str = \"Time, \"\n",
        "    for t in timing:\n",
        "        time_str = time_str + \" , \" + str(t)\n",
        "    logger.info(timing)\n",
        "\n",
        "timing_array = np.array(timing)\n",
        "cum_timing = list(generate_cumulative_timing(timing_array))\n",
        "logger.info(\"Total time taken by %s = %.4f \", selection_strategy, cum_timing[-1])\n",
        "\n",
        "\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[10/19 19:28:35] __main__ INFO: Epoch: 1 , Validation Loss: 492.08510732650757 , Validation Accuracy: 0.2628 , Test Loss: 19.398529291152954 , Test Accuracy: 0.2813 , Timing: 7.384371995925903\n",
            "[10/19 19:28:50] __main__ INFO: Epoch: 2 , Validation Loss: 462.1935260295868 , Validation Accuracy: 0.3188 , Test Loss: 17.76013433933258 , Test Accuracy: 0.3481 , Timing: 7.13740348815918\n",
            "[10/19 19:29:05] __main__ INFO: Epoch: 3 , Validation Loss: 406.5821268558502 , Validation Accuracy: 0.392 , Test Loss: 15.817896723747253 , Test Accuracy: 0.4154 , Timing: 7.385777473449707\n",
            "[10/19 19:29:20] __main__ INFO: Epoch: 4 , Validation Loss: 384.15974444150925 , Validation Accuracy: 0.4114 , Test Loss: 15.2018541097641 , Test Accuracy: 0.4156 , Timing: 7.098953723907471\n",
            "[10/19 19:29:34] __main__ INFO: Epoch: 5 , Validation Loss: 380.91055899858475 , Validation Accuracy: 0.4478 , Test Loss: 14.655442833900452 , Test Accuracy: 0.4532 , Timing: 7.023574590682983\n",
            "[10/19 19:29:49] __main__ INFO: Epoch: 6 , Validation Loss: 436.0215688943863 , Validation Accuracy: 0.4044 , Test Loss: 17.362621426582336 , Test Accuracy: 0.405 , Timing: 7.1023850440979\n",
            "[10/19 19:30:04] __main__ INFO: Epoch: 7 , Validation Loss: 362.51565581560135 , Validation Accuracy: 0.4732 , Test Loss: 14.166994333267212 , Test Accuracy: 0.475 , Timing: 7.057530403137207\n",
            "[10/19 19:30:18] __main__ INFO: Epoch: 8 , Validation Loss: 342.44397324323654 , Validation Accuracy: 0.5078 , Test Loss: 13.583762049674988 , Test Accuracy: 0.5049 , Timing: 7.0412962436676025\n",
            "[10/19 19:30:33] __main__ INFO: Epoch: 9 , Validation Loss: 327.58211040496826 , Validation Accuracy: 0.5356 , Test Loss: 13.393548250198364 , Test Accuracy: 0.5255 , Timing: 6.9713523387908936\n",
            "[10/19 19:30:48] __main__ INFO: Epoch: 10 , Validation Loss: 300.23407459259033 , Validation Accuracy: 0.5684 , Test Loss: 11.836564660072327 , Test Accuracy: 0.5785 , Timing: 7.169913053512573\n",
            "[10/19 19:31:03] __main__ INFO: Epoch: 11 , Validation Loss: 292.16381800174713 , Validation Accuracy: 0.5798 , Test Loss: 11.631034016609192 , Test Accuracy: 0.5755 , Timing: 6.917028427124023\n",
            "[10/19 19:31:17] __main__ INFO: Epoch: 12 , Validation Loss: 310.27985870838165 , Validation Accuracy: 0.562 , Test Loss: 12.51718521118164 , Test Accuracy: 0.5573 , Timing: 7.078837633132935\n",
            "[10/19 19:31:32] __main__ INFO: Epoch: 13 , Validation Loss: 283.4440655708313 , Validation Accuracy: 0.608 , Test Loss: 11.246467351913452 , Test Accuracy: 0.6052 , Timing: 6.929779529571533\n",
            "[10/19 19:31:46] __main__ INFO: Epoch: 14 , Validation Loss: 271.4170537889004 , Validation Accuracy: 0.6132 , Test Loss: 10.317887663841248 , Test Accuracy: 0.6307 , Timing: 7.066625118255615\n",
            "[10/19 19:32:01] __main__ INFO: Epoch: 15 , Validation Loss: 278.9781844615936 , Validation Accuracy: 0.6034 , Test Loss: 10.922091960906982 , Test Accuracy: 0.622 , Timing: 7.0596184730529785\n",
            "[10/19 19:32:15] __main__ INFO: Epoch: 16 , Validation Loss: 260.86412489414215 , Validation Accuracy: 0.6306 , Test Loss: 10.446364402770996 , Test Accuracy: 0.6345 , Timing: 7.111894369125366\n",
            "[10/19 19:32:30] __main__ INFO: Epoch: 17 , Validation Loss: 256.6770333945751 , Validation Accuracy: 0.6288 , Test Loss: 10.358157277107239 , Test Accuracy: 0.625 , Timing: 7.117462396621704\n",
            "[10/19 19:32:44] __main__ INFO: Epoch: 18 , Validation Loss: 271.8172527551651 , Validation Accuracy: 0.6268 , Test Loss: 10.857910990715027 , Test Accuracy: 0.628 , Timing: 7.158030271530151\n",
            "[10/19 19:32:59] __main__ INFO: Epoch: 19 , Validation Loss: 259.3506673872471 , Validation Accuracy: 0.6416 , Test Loss: 10.608243823051453 , Test Accuracy: 0.6428 , Timing: 6.929781198501587\n",
            "[10/19 19:33:13] __main__ INFO: Epoch: 20 , Validation Loss: 243.66167396306992 , Validation Accuracy: 0.6556 , Test Loss: 9.844395101070404 , Test Accuracy: 0.6711 , Timing: 6.850125789642334\n",
            "[10/19 19:33:14] __main__ INFO: Model checkpoint saved at epoch: 20\n",
            "[10/19 19:33:54] __main__ INFO: Epoch: 21, GLISTER dataloader subset selection finished, takes 40.2558. \n",
            "[10/19 19:34:08] __main__ INFO: Epoch: 21 , Validation Loss: 284.6038470864296 , Validation Accuracy: 0.5928 , Test Loss: 10.75424838066101 , Test Accuracy: 0.6032 , Timing: 47.12204098701477\n",
            "[10/19 19:34:23] __main__ INFO: Epoch: 22 , Validation Loss: 271.86955869197845 , Validation Accuracy: 0.6048 , Test Loss: 10.521610379219055 , Test Accuracy: 0.6184 , Timing: 7.093116760253906\n",
            "[10/19 19:34:38] __main__ INFO: Epoch: 23 , Validation Loss: 272.87072491645813 , Validation Accuracy: 0.6072 , Test Loss: 10.02290415763855 , Test Accuracy: 0.6352 , Timing: 6.926366806030273\n",
            "[10/19 19:34:52] __main__ INFO: Epoch: 24 , Validation Loss: 297.9318799376488 , Validation Accuracy: 0.5688 , Test Loss: 11.506828904151917 , Test Accuracy: 0.5756 , Timing: 6.966162919998169\n",
            "[10/19 19:35:07] __main__ INFO: Epoch: 25 , Validation Loss: 275.13073551654816 , Validation Accuracy: 0.59 , Test Loss: 10.313598930835724 , Test Accuracy: 0.6229 , Timing: 6.91947340965271\n",
            "[10/19 19:35:21] __main__ INFO: Epoch: 26 , Validation Loss: 296.5569551885128 , Validation Accuracy: 0.5882 , Test Loss: 11.60485303401947 , Test Accuracy: 0.6049 , Timing: 7.087247133255005\n",
            "[10/19 19:35:36] __main__ INFO: Epoch: 27 , Validation Loss: 247.67118319869041 , Validation Accuracy: 0.6452 , Test Loss: 9.580496668815613 , Test Accuracy: 0.6531 , Timing: 7.267578840255737\n",
            "[10/19 19:35:51] __main__ INFO: Epoch: 28 , Validation Loss: 252.11789506673813 , Validation Accuracy: 0.6322 , Test Loss: 9.454959154129028 , Test Accuracy: 0.6516 , Timing: 6.9669859409332275\n",
            "[10/19 19:36:05] __main__ INFO: Epoch: 29 , Validation Loss: 300.38346642255783 , Validation Accuracy: 0.57 , Test Loss: 11.763608694076538 , Test Accuracy: 0.5948 , Timing: 7.2587890625\n",
            "[10/19 19:36:20] __main__ INFO: Epoch: 30 , Validation Loss: 272.5996436178684 , Validation Accuracy: 0.6018 , Test Loss: 11.079327583312988 , Test Accuracy: 0.6034 , Timing: 6.953402280807495\n",
            "[10/19 19:36:35] __main__ INFO: Epoch: 31 , Validation Loss: 240.07896229624748 , Validation Accuracy: 0.654 , Test Loss: 9.163584470748901 , Test Accuracy: 0.6675 , Timing: 6.945340633392334\n",
            "[10/19 19:36:49] __main__ INFO: Epoch: 32 , Validation Loss: 310.2935206890106 , Validation Accuracy: 0.5794 , Test Loss: 12.844932436943054 , Test Accuracy: 0.5785 , Timing: 6.923909664154053\n",
            "[10/19 19:37:04] __main__ INFO: Epoch: 33 , Validation Loss: 243.6154717206955 , Validation Accuracy: 0.656 , Test Loss: 9.858159303665161 , Test Accuracy: 0.6504 , Timing: 7.093969106674194\n",
            "[10/19 19:37:18] __main__ INFO: Epoch: 34 , Validation Loss: 268.7556462883949 , Validation Accuracy: 0.6234 , Test Loss: 10.292556345462799 , Test Accuracy: 0.6469 , Timing: 7.035810947418213\n",
            "[10/19 19:37:33] __main__ INFO: Epoch: 35 , Validation Loss: 244.35470741987228 , Validation Accuracy: 0.6522 , Test Loss: 9.936725795269012 , Test Accuracy: 0.6597 , Timing: 7.1172168254852295\n",
            "[10/19 19:37:47] __main__ INFO: Epoch: 36 , Validation Loss: 251.09840083122253 , Validation Accuracy: 0.647 , Test Loss: 10.004657447338104 , Test Accuracy: 0.6531 , Timing: 7.117445230484009\n",
            "[10/19 19:38:02] __main__ INFO: Epoch: 37 , Validation Loss: 265.9187017083168 , Validation Accuracy: 0.6232 , Test Loss: 10.548588693141937 , Test Accuracy: 0.6353 , Timing: 6.845056533813477\n",
            "[10/19 19:38:17] __main__ INFO: Epoch: 38 , Validation Loss: 265.443911164999 , Validation Accuracy: 0.6362 , Test Loss: 10.137325167655945 , Test Accuracy: 0.6583 , Timing: 7.071715354919434\n",
            "[10/19 19:38:31] __main__ INFO: Epoch: 39 , Validation Loss: 236.60534673929214 , Validation Accuracy: 0.6756 , Test Loss: 9.735805749893188 , Test Accuracy: 0.6797 , Timing: 6.901402473449707\n",
            "[10/19 19:38:46] __main__ INFO: Epoch: 40 , Validation Loss: 274.72547191381454 , Validation Accuracy: 0.6294 , Test Loss: 10.811769008636475 , Test Accuracy: 0.6531 , Timing: 7.242793083190918\n",
            "[10/19 19:38:47] __main__ INFO: Model checkpoint saved at epoch: 40\n",
            "[10/19 19:39:28] __main__ INFO: Epoch: 41, GLISTER dataloader subset selection finished, takes 41.5809. \n",
            "[10/19 19:39:43] __main__ INFO: Epoch: 41 , Validation Loss: 188.86843618750572 , Validation Accuracy: 0.7368 , Test Loss: 7.39064633846283 , Test Accuracy: 0.7412 , Timing: 48.59373164176941\n",
            "[10/19 19:39:57] __main__ INFO: Epoch: 42 , Validation Loss: 207.26006412506104 , Validation Accuracy: 0.7116 , Test Loss: 8.32771921157837 , Test Accuracy: 0.7262 , Timing: 7.098180532455444\n",
            "[10/19 19:40:12] __main__ INFO: Epoch: 43 , Validation Loss: 201.2406938225031 , Validation Accuracy: 0.7174 , Test Loss: 7.909608840942383 , Test Accuracy: 0.7308 , Timing: 6.861478805541992\n",
            "[10/19 19:40:27] __main__ INFO: Epoch: 44 , Validation Loss: 179.70300072431564 , Validation Accuracy: 0.7518 , Test Loss: 7.2497493624687195 , Test Accuracy: 0.7472 , Timing: 7.064647436141968\n",
            "[10/19 19:40:41] __main__ INFO: Epoch: 45 , Validation Loss: 187.23165959119797 , Validation Accuracy: 0.745 , Test Loss: 7.4533326625823975 , Test Accuracy: 0.7475 , Timing: 6.903841018676758\n",
            "[10/19 19:40:56] __main__ INFO: Epoch: 46 , Validation Loss: 210.82746270298958 , Validation Accuracy: 0.7094 , Test Loss: 8.305178225040436 , Test Accuracy: 0.7216 , Timing: 6.9639892578125\n",
            "[10/19 19:41:10] __main__ INFO: Epoch: 47 , Validation Loss: 201.3571225181222 , Validation Accuracy: 0.718 , Test Loss: 8.135901868343353 , Test Accuracy: 0.73 , Timing: 6.881036996841431\n",
            "[10/19 19:41:25] __main__ INFO: Epoch: 48 , Validation Loss: 184.00309674441814 , Validation Accuracy: 0.7496 , Test Loss: 7.438991606235504 , Test Accuracy: 0.7532 , Timing: 7.190493106842041\n",
            "[10/19 19:41:39] __main__ INFO: Epoch: 49 , Validation Loss: 190.48100917041302 , Validation Accuracy: 0.7414 , Test Loss: 7.553747117519379 , Test Accuracy: 0.749 , Timing: 6.985577583312988\n",
            "[10/19 19:41:54] __main__ INFO: Epoch: 50 , Validation Loss: 187.08595202863216 , Validation Accuracy: 0.7496 , Test Loss: 7.63663125038147 , Test Accuracy: 0.7571 , Timing: 7.114895820617676\n",
            "[10/19 19:42:08] __main__ INFO: Epoch: 51 , Validation Loss: 167.96041613817215 , Validation Accuracy: 0.768 , Test Loss: 7.3376041650772095 , Test Accuracy: 0.7637 , Timing: 6.987545490264893\n",
            "[10/19 19:42:23] __main__ INFO: Epoch: 52 , Validation Loss: 182.19049078971148 , Validation Accuracy: 0.749 , Test Loss: 7.216418981552124 , Test Accuracy: 0.7649 , Timing: 7.091436862945557\n",
            "[10/19 19:42:38] __main__ INFO: Epoch: 53 , Validation Loss: 224.08330710232258 , Validation Accuracy: 0.7172 , Test Loss: 9.122800529003143 , Test Accuracy: 0.7199 , Timing: 7.132680892944336\n",
            "[10/19 19:42:52] __main__ INFO: Epoch: 54 , Validation Loss: 198.7618813365698 , Validation Accuracy: 0.7314 , Test Loss: 7.623515427112579 , Test Accuracy: 0.7544 , Timing: 6.949977159500122\n",
            "[10/19 19:43:07] __main__ INFO: Epoch: 55 , Validation Loss: 211.52477636933327 , Validation Accuracy: 0.7372 , Test Loss: 8.383858025074005 , Test Accuracy: 0.7522 , Timing: 6.8768322467803955\n",
            "[10/19 19:43:21] __main__ INFO: Epoch: 56 , Validation Loss: 257.35280089080334 , Validation Accuracy: 0.6886 , Test Loss: 11.004509329795837 , Test Accuracy: 0.6874 , Timing: 6.940402507781982\n",
            "[10/19 19:43:35] __main__ INFO: Epoch: 57 , Validation Loss: 201.983826816082 , Validation Accuracy: 0.7372 , Test Loss: 8.498774528503418 , Test Accuracy: 0.7413 , Timing: 6.944841384887695\n",
            "[10/19 19:43:50] __main__ INFO: Epoch: 58 , Validation Loss: 205.46111766994 , Validation Accuracy: 0.7368 , Test Loss: 9.376294493675232 , Test Accuracy: 0.7325 , Timing: 6.949514150619507\n",
            "[10/19 19:44:04] __main__ INFO: Epoch: 59 , Validation Loss: 212.6609766483307 , Validation Accuracy: 0.7396 , Test Loss: 9.16330099105835 , Test Accuracy: 0.743 , Timing: 6.88653564453125\n",
            "[10/19 19:44:19] __main__ INFO: Epoch: 60 , Validation Loss: 234.84215535223484 , Validation Accuracy: 0.7062 , Test Loss: 9.995932519435883 , Test Accuracy: 0.7178 , Timing: 7.051782131195068\n",
            "[10/19 19:44:19] __main__ INFO: Model checkpoint saved at epoch: 60\n",
            "[10/19 19:44:59] __main__ INFO: Epoch: 61, GLISTER dataloader subset selection finished, takes 39.5948. \n",
            "[10/19 19:45:13] __main__ INFO: Epoch: 61 , Validation Loss: 179.02862253785133 , Validation Accuracy: 0.7412 , Test Loss: 7.144235014915466 , Test Accuracy: 0.7473 , Timing: 46.66924452781677\n",
            "[10/19 19:45:28] __main__ INFO: Epoch: 62 , Validation Loss: 165.76878425478935 , Validation Accuracy: 0.758 , Test Loss: 6.459660053253174 , Test Accuracy: 0.7728 , Timing: 7.36673641204834\n",
            "[10/19 19:45:43] __main__ INFO: Epoch: 63 , Validation Loss: 167.6530125886202 , Validation Accuracy: 0.7612 , Test Loss: 6.534153640270233 , Test Accuracy: 0.7701 , Timing: 6.965766668319702\n",
            "[10/19 19:45:57] __main__ INFO: Epoch: 64 , Validation Loss: 185.81999443471432 , Validation Accuracy: 0.7278 , Test Loss: 7.480808973312378 , Test Accuracy: 0.7355 , Timing: 7.211893796920776\n",
            "[10/19 19:46:12] __main__ INFO: Epoch: 65 , Validation Loss: 182.15535517036915 , Validation Accuracy: 0.7464 , Test Loss: 7.591017186641693 , Test Accuracy: 0.7412 , Timing: 7.115354061126709\n",
            "[10/19 19:46:27] __main__ INFO: Epoch: 66 , Validation Loss: 158.3308206796646 , Validation Accuracy: 0.7786 , Test Loss: 6.617851078510284 , Test Accuracy: 0.7724 , Timing: 7.102010726928711\n",
            "[10/19 19:46:42] __main__ INFO: Epoch: 67 , Validation Loss: 167.49169358611107 , Validation Accuracy: 0.765 , Test Loss: 6.742274165153503 , Test Accuracy: 0.7713 , Timing: 7.162262916564941\n",
            "[10/19 19:46:57] __main__ INFO: Epoch: 68 , Validation Loss: 169.41967387497425 , Validation Accuracy: 0.7684 , Test Loss: 6.857652008533478 , Test Accuracy: 0.7776 , Timing: 6.9851953983306885\n",
            "[10/19 19:47:11] __main__ INFO: Epoch: 69 , Validation Loss: 176.22394980490208 , Validation Accuracy: 0.758 , Test Loss: 7.185362994670868 , Test Accuracy: 0.7678 , Timing: 7.321051836013794\n",
            "[10/19 19:47:26] __main__ INFO: Epoch: 70 , Validation Loss: 187.3963697552681 , Validation Accuracy: 0.7668 , Test Loss: 8.18281865119934 , Test Accuracy: 0.7549 , Timing: 7.109907150268555\n",
            "[10/19 19:47:41] __main__ INFO: Epoch: 71 , Validation Loss: 197.42998799681664 , Validation Accuracy: 0.7332 , Test Loss: 8.249788463115692 , Test Accuracy: 0.74 , Timing: 7.209913492202759\n",
            "[10/19 19:47:56] __main__ INFO: Epoch: 72 , Validation Loss: 182.81028674542904 , Validation Accuracy: 0.7606 , Test Loss: 7.503394544124603 , Test Accuracy: 0.7713 , Timing: 7.036122560501099\n",
            "[10/19 19:48:11] __main__ INFO: Epoch: 73 , Validation Loss: 170.60488981753588 , Validation Accuracy: 0.78 , Test Loss: 7.181918025016785 , Test Accuracy: 0.7795 , Timing: 7.027761459350586\n",
            "[10/19 19:48:26] __main__ INFO: Epoch: 74 , Validation Loss: 206.9090990126133 , Validation Accuracy: 0.7442 , Test Loss: 8.676189005374908 , Test Accuracy: 0.7487 , Timing: 7.225622653961182\n",
            "[10/19 19:48:41] __main__ INFO: Epoch: 75 , Validation Loss: 212.68774999678135 , Validation Accuracy: 0.7498 , Test Loss: 8.751668810844421 , Test Accuracy: 0.7541 , Timing: 7.066665887832642\n",
            "[10/19 19:48:55] __main__ INFO: Epoch: 76 , Validation Loss: 198.07763665914536 , Validation Accuracy: 0.759 , Test Loss: 7.74499648809433 , Test Accuracy: 0.7752 , Timing: 7.120447397232056\n",
            "[10/19 19:49:10] __main__ INFO: Epoch: 77 , Validation Loss: 217.63727498054504 , Validation Accuracy: 0.746 , Test Loss: 8.815999865531921 , Test Accuracy: 0.7577 , Timing: 6.948836326599121\n",
            "[10/19 19:49:24] __main__ INFO: Epoch: 78 , Validation Loss: 192.4555062018335 , Validation Accuracy: 0.7644 , Test Loss: 8.400214195251465 , Test Accuracy: 0.7642 , Timing: 6.9664671421051025\n",
            "[10/19 19:49:39] __main__ INFO: Epoch: 79 , Validation Loss: 199.3766826838255 , Validation Accuracy: 0.759 , Test Loss: 8.59498655796051 , Test Accuracy: 0.7564 , Timing: 7.128437042236328\n",
            "[10/19 19:49:54] __main__ INFO: Epoch: 80 , Validation Loss: 214.31042162328959 , Validation Accuracy: 0.747 , Test Loss: 8.965057849884033 , Test Accuracy: 0.7517 , Timing: 7.070586919784546\n",
            "[10/19 19:49:54] __main__ INFO: Model checkpoint saved at epoch: 80\n",
            "[10/19 19:50:35] __main__ INFO: Epoch: 81, GLISTER dataloader subset selection finished, takes 40.4166. \n",
            "[10/19 19:50:49] __main__ INFO: Epoch: 81 , Validation Loss: 155.27281421422958 , Validation Accuracy: 0.7768 , Test Loss: 6.2469770312309265 , Test Accuracy: 0.7741 , Timing: 47.383991718292236\n",
            "[10/19 19:51:04] __main__ INFO: Epoch: 82 , Validation Loss: 145.02184934914112 , Validation Accuracy: 0.795 , Test Loss: 5.638074278831482 , Test Accuracy: 0.8051 , Timing: 7.005133390426636\n",
            "[10/19 19:51:18] __main__ INFO: Epoch: 83 , Validation Loss: 147.23564840853214 , Validation Accuracy: 0.7894 , Test Loss: 5.881366670131683 , Test Accuracy: 0.803 , Timing: 6.933480978012085\n",
            "[10/19 19:51:33] __main__ INFO: Epoch: 84 , Validation Loss: 136.75849364697933 , Validation Accuracy: 0.807 , Test Loss: 5.438498079776764 , Test Accuracy: 0.8154 , Timing: 7.170032262802124\n",
            "[10/19 19:51:48] __main__ INFO: Epoch: 85 , Validation Loss: 148.93780145049095 , Validation Accuracy: 0.7956 , Test Loss: 5.884674906730652 , Test Accuracy: 0.8087 , Timing: 7.016028165817261\n",
            "[10/19 19:52:02] __main__ INFO: Epoch: 86 , Validation Loss: 187.7990271449089 , Validation Accuracy: 0.7566 , Test Loss: 7.376343250274658 , Test Accuracy: 0.7639 , Timing: 7.012562036514282\n",
            "[10/19 19:52:17] __main__ INFO: Epoch: 87 , Validation Loss: 187.31119230389595 , Validation Accuracy: 0.7536 , Test Loss: 7.568432033061981 , Test Accuracy: 0.7586 , Timing: 7.056127548217773\n",
            "[10/19 19:52:31] __main__ INFO: Epoch: 88 , Validation Loss: 151.83437184989452 , Validation Accuracy: 0.7998 , Test Loss: 6.239842057228088 , Test Accuracy: 0.8067 , Timing: 7.047255516052246\n",
            "[10/19 19:52:46] __main__ INFO: Epoch: 89 , Validation Loss: 154.89519707858562 , Validation Accuracy: 0.794 , Test Loss: 6.622120797634125 , Test Accuracy: 0.7906 , Timing: 6.959128141403198\n",
            "[10/19 19:53:01] __main__ INFO: Epoch: 90 , Validation Loss: 168.011096842587 , Validation Accuracy: 0.7732 , Test Loss: 6.863397181034088 , Test Accuracy: 0.7903 , Timing: 7.0726964473724365\n",
            "[10/19 19:53:15] __main__ INFO: Epoch: 91 , Validation Loss: 155.90827155858278 , Validation Accuracy: 0.7982 , Test Loss: 6.443131387233734 , Test Accuracy: 0.8084 , Timing: 6.909491777420044\n",
            "[10/19 19:53:30] __main__ INFO: Epoch: 92 , Validation Loss: 175.69880571961403 , Validation Accuracy: 0.7664 , Test Loss: 7.3641955852508545 , Test Accuracy: 0.7802 , Timing: 6.878342390060425\n",
            "[10/19 19:53:45] __main__ INFO: Epoch: 93 , Validation Loss: 197.70272456854582 , Validation Accuracy: 0.7634 , Test Loss: 8.287022292613983 , Test Accuracy: 0.7702 , Timing: 7.092571020126343\n",
            "[10/19 19:53:59] __main__ INFO: Epoch: 94 , Validation Loss: 178.22721043229103 , Validation Accuracy: 0.7878 , Test Loss: 7.180034101009369 , Test Accuracy: 0.7977 , Timing: 7.082178115844727\n",
            "[10/19 19:54:14] __main__ INFO: Epoch: 95 , Validation Loss: 181.37198916822672 , Validation Accuracy: 0.781 , Test Loss: 7.394755244255066 , Test Accuracy: 0.7962 , Timing: 7.038629531860352\n",
            "[10/19 19:54:28] __main__ INFO: Epoch: 96 , Validation Loss: 197.09386394172907 , Validation Accuracy: 0.763 , Test Loss: 8.02722829580307 , Test Accuracy: 0.7825 , Timing: 7.092298746109009\n",
            "[10/19 19:54:43] __main__ INFO: Epoch: 97 , Validation Loss: 170.14967965334654 , Validation Accuracy: 0.7906 , Test Loss: 6.962323844432831 , Test Accuracy: 0.8052 , Timing: 7.085120916366577\n",
            "[10/19 19:54:58] __main__ INFO: Epoch: 98 , Validation Loss: 184.70851833373308 , Validation Accuracy: 0.7844 , Test Loss: 7.546197950839996 , Test Accuracy: 0.7947 , Timing: 7.2416465282440186\n",
            "[10/19 19:55:13] __main__ INFO: Epoch: 99 , Validation Loss: 201.19679912924767 , Validation Accuracy: 0.7824 , Test Loss: 8.561929702758789 , Test Accuracy: 0.7788 , Timing: 7.210483551025391\n",
            "[10/19 19:55:27] __main__ INFO: Epoch: 100 , Validation Loss: 191.2862502336502 , Validation Accuracy: 0.7908 , Test Loss: 7.929006040096283 , Test Accuracy: 0.7939 , Timing: 6.930764198303223\n",
            "[10/19 19:55:28] __main__ INFO: Model checkpoint saved at epoch: 100\n",
            "[10/19 19:56:08] __main__ INFO: Epoch: 101, GLISTER dataloader subset selection finished, takes 40.0751. \n",
            "[10/19 19:56:22] __main__ INFO: Epoch: 101 , Validation Loss: 153.24161805212498 , Validation Accuracy: 0.7754 , Test Loss: 5.625534236431122 , Test Accuracy: 0.803 , Timing: 47.07431244850159\n",
            "[10/19 19:56:37] __main__ INFO: Epoch: 102 , Validation Loss: 131.22202676534653 , Validation Accuracy: 0.811 , Test Loss: 5.128549784421921 , Test Accuracy: 0.8146 , Timing: 7.126592397689819\n",
            "[10/19 19:56:52] __main__ INFO: Epoch: 103 , Validation Loss: 140.0789419710636 , Validation Accuracy: 0.803 , Test Loss: 5.4606218338012695 , Test Accuracy: 0.8085 , Timing: 6.9356160163879395\n",
            "[10/19 19:57:07] __main__ INFO: Epoch: 104 , Validation Loss: 125.13339870423079 , Validation Accuracy: 0.8292 , Test Loss: 5.036296159029007 , Test Accuracy: 0.8301 , Timing: 7.002356052398682\n",
            "[10/19 19:57:21] __main__ INFO: Epoch: 105 , Validation Loss: 162.137497022748 , Validation Accuracy: 0.7842 , Test Loss: 6.543257474899292 , Test Accuracy: 0.7944 , Timing: 6.95437479019165\n",
            "[10/19 19:57:36] __main__ INFO: Epoch: 106 , Validation Loss: 164.47448885440826 , Validation Accuracy: 0.7792 , Test Loss: 6.387215435504913 , Test Accuracy: 0.7964 , Timing: 7.113501071929932\n",
            "[10/19 19:57:50] __main__ INFO: Epoch: 107 , Validation Loss: 159.53937081247568 , Validation Accuracy: 0.7928 , Test Loss: 6.555147171020508 , Test Accuracy: 0.7964 , Timing: 6.998284578323364\n",
            "[10/19 19:58:05] __main__ INFO: Epoch: 108 , Validation Loss: 161.3360228985548 , Validation Accuracy: 0.7974 , Test Loss: 6.745216012001038 , Test Accuracy: 0.793 , Timing: 6.990890264511108\n",
            "[10/19 19:58:19] __main__ INFO: Epoch: 109 , Validation Loss: 162.17839232087135 , Validation Accuracy: 0.7962 , Test Loss: 6.368681848049164 , Test Accuracy: 0.8066 , Timing: 7.118534803390503\n",
            "[10/19 19:58:34] __main__ INFO: Epoch: 110 , Validation Loss: 153.03522443771362 , Validation Accuracy: 0.8044 , Test Loss: 5.959374666213989 , Test Accuracy: 0.8169 , Timing: 7.213549852371216\n",
            "[10/19 19:58:49] __main__ INFO: Epoch: 111 , Validation Loss: 158.4620806761086 , Validation Accuracy: 0.8056 , Test Loss: 6.3487162590026855 , Test Accuracy: 0.8087 , Timing: 6.946784973144531\n",
            "[10/19 19:59:03] __main__ INFO: Epoch: 112 , Validation Loss: 154.29458609968424 , Validation Accuracy: 0.8054 , Test Loss: 6.4338299036026 , Test Accuracy: 0.8116 , Timing: 6.991189479827881\n",
            "[10/19 19:59:18] __main__ INFO: Epoch: 113 , Validation Loss: 215.97931910306215 , Validation Accuracy: 0.7576 , Test Loss: 8.652061581611633 , Test Accuracy: 0.7642 , Timing: 7.2418434619903564\n",
            "[10/19 19:59:33] __main__ INFO: Epoch: 114 , Validation Loss: 199.44474786147475 , Validation Accuracy: 0.7738 , Test Loss: 8.074406087398529 , Test Accuracy: 0.7852 , Timing: 7.060486316680908\n",
            "[10/19 19:59:48] __main__ INFO: Epoch: 115 , Validation Loss: 175.12682769447565 , Validation Accuracy: 0.788 , Test Loss: 7.565969526767731 , Test Accuracy: 0.787 , Timing: 7.303961277008057\n",
            "[10/19 20:00:02] __main__ INFO: Epoch: 116 , Validation Loss: 170.91896366328 , Validation Accuracy: 0.7928 , Test Loss: 6.800839364528656 , Test Accuracy: 0.809 , Timing: 7.118740558624268\n",
            "[10/19 20:00:16] __main__ INFO: Epoch: 117 , Validation Loss: 169.46627574414015 , Validation Accuracy: 0.8044 , Test Loss: 6.943419992923737 , Test Accuracy: 0.8102 , Timing: 6.88494610786438\n",
            "[10/19 20:00:31] __main__ INFO: Epoch: 118 , Validation Loss: 164.5971268452704 , Validation Accuracy: 0.8046 , Test Loss: 6.88622909784317 , Test Accuracy: 0.8158 , Timing: 7.18582820892334\n",
            "[10/19 20:00:46] __main__ INFO: Epoch: 119 , Validation Loss: 180.29208890348673 , Validation Accuracy: 0.7922 , Test Loss: 7.408359467983246 , Test Accuracy: 0.799 , Timing: 6.960339069366455\n",
            "[10/19 20:01:00] __main__ INFO: Epoch: 120 , Validation Loss: 175.23802842572331 , Validation Accuracy: 0.801 , Test Loss: 7.087670147418976 , Test Accuracy: 0.8116 , Timing: 7.07940411567688\n",
            "[10/19 20:01:00] __main__ INFO: Model checkpoint saved at epoch: 120\n",
            "[10/19 20:01:41] __main__ INFO: Epoch: 121, GLISTER dataloader subset selection finished, takes 40.7762. \n",
            "[10/19 20:01:56] __main__ INFO: Epoch: 121 , Validation Loss: 146.99823847413063 , Validation Accuracy: 0.7968 , Test Loss: 5.704290688037872 , Test Accuracy: 0.8085 , Timing: 47.99230623245239\n",
            "[10/19 20:02:11] __main__ INFO: Epoch: 122 , Validation Loss: 117.68176108598709 , Validation Accuracy: 0.8332 , Test Loss: 4.5916624665260315 , Test Accuracy: 0.8407 , Timing: 7.161189317703247\n",
            "[10/19 20:02:26] __main__ INFO: Epoch: 123 , Validation Loss: 124.96651737391949 , Validation Accuracy: 0.8302 , Test Loss: 4.88205686211586 , Test Accuracy: 0.8374 , Timing: 7.122603893280029\n",
            "[10/19 20:02:41] __main__ INFO: Epoch: 124 , Validation Loss: 134.5686968229711 , Validation Accuracy: 0.8204 , Test Loss: 5.347210198640823 , Test Accuracy: 0.8285 , Timing: 6.996565103530884\n",
            "[10/19 20:02:55] __main__ INFO: Epoch: 125 , Validation Loss: 134.354761749506 , Validation Accuracy: 0.8232 , Test Loss: 5.296544253826141 , Test Accuracy: 0.8318 , Timing: 6.889234781265259\n",
            "[10/19 20:03:10] __main__ INFO: Epoch: 126 , Validation Loss: 138.27847772277892 , Validation Accuracy: 0.826 , Test Loss: 5.64277520775795 , Test Accuracy: 0.8263 , Timing: 7.000261545181274\n",
            "[10/19 20:03:24] __main__ INFO: Epoch: 127 , Validation Loss: 130.88200693577528 , Validation Accuracy: 0.8312 , Test Loss: 5.4492117166519165 , Test Accuracy: 0.8402 , Timing: 6.826524019241333\n",
            "[10/19 20:03:39] __main__ INFO: Epoch: 128 , Validation Loss: 136.59834815189242 , Validation Accuracy: 0.8274 , Test Loss: 5.669817864894867 , Test Accuracy: 0.8297 , Timing: 7.105190277099609\n",
            "[10/19 20:03:53] __main__ INFO: Epoch: 129 , Validation Loss: 129.0943593941629 , Validation Accuracy: 0.8362 , Test Loss: 5.368352979421616 , Test Accuracy: 0.8381 , Timing: 6.99227499961853\n",
            "[10/19 20:04:08] __main__ INFO: Epoch: 130 , Validation Loss: 135.92383495159447 , Validation Accuracy: 0.829 , Test Loss: 5.668733716011047 , Test Accuracy: 0.8363 , Timing: 7.103026628494263\n",
            "[10/19 20:04:23] __main__ INFO: Epoch: 131 , Validation Loss: 139.6357704102993 , Validation Accuracy: 0.8382 , Test Loss: 5.732183277606964 , Test Accuracy: 0.8397 , Timing: 6.978654146194458\n",
            "[10/19 20:04:37] __main__ INFO: Epoch: 132 , Validation Loss: 143.03862323611975 , Validation Accuracy: 0.8256 , Test Loss: 5.754538178443909 , Test Accuracy: 0.836 , Timing: 7.183685302734375\n",
            "[10/19 20:04:52] __main__ INFO: Epoch: 133 , Validation Loss: 164.38674141466618 , Validation Accuracy: 0.8142 , Test Loss: 6.468526303768158 , Test Accuracy: 0.8146 , Timing: 7.220989942550659\n",
            "[10/19 20:05:07] __main__ INFO: Epoch: 134 , Validation Loss: 158.38328400999308 , Validation Accuracy: 0.8158 , Test Loss: 6.492070913314819 , Test Accuracy: 0.8242 , Timing: 6.921018838882446\n",
            "[10/19 20:05:22] __main__ INFO: Epoch: 135 , Validation Loss: 141.43499970063567 , Validation Accuracy: 0.833 , Test Loss: 6.108040273189545 , Test Accuracy: 0.833 , Timing: 7.125321865081787\n",
            "[10/19 20:05:36] __main__ INFO: Epoch: 136 , Validation Loss: 163.42201919853687 , Validation Accuracy: 0.8182 , Test Loss: 6.832237124443054 , Test Accuracy: 0.8212 , Timing: 6.885319709777832\n",
            "[10/19 20:05:51] __main__ INFO: Epoch: 137 , Validation Loss: 155.7346240375191 , Validation Accuracy: 0.8232 , Test Loss: 6.498092770576477 , Test Accuracy: 0.8295 , Timing: 7.168548107147217\n",
            "[10/19 20:06:06] __main__ INFO: Epoch: 138 , Validation Loss: 146.2319353017956 , Validation Accuracy: 0.8316 , Test Loss: 5.7573694586753845 , Test Accuracy: 0.8444 , Timing: 7.019357681274414\n",
            "[10/19 20:06:20] __main__ INFO: Epoch: 139 , Validation Loss: 153.8008204791695 , Validation Accuracy: 0.8306 , Test Loss: 6.553818643093109 , Test Accuracy: 0.8307 , Timing: 7.0251429080963135\n",
            "[10/19 20:06:35] __main__ INFO: Epoch: 140 , Validation Loss: 169.6579960435629 , Validation Accuracy: 0.8158 , Test Loss: 6.6945300698280334 , Test Accuracy: 0.8295 , Timing: 7.130890846252441\n",
            "[10/19 20:06:35] __main__ INFO: Model checkpoint saved at epoch: 140\n",
            "[10/19 20:07:17] __main__ INFO: Epoch: 141, GLISTER dataloader subset selection finished, takes 41.7080. \n",
            "[10/19 20:07:31] __main__ INFO: Epoch: 141 , Validation Loss: 131.62064509093761 , Validation Accuracy: 0.8102 , Test Loss: 5.05686891078949 , Test Accuracy: 0.8209 , Timing: 48.66609525680542\n"
          ]
        }
      ]
    }
  ]
}