{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bc1280f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "import os, time, copy\n",
    "import torch\n",
    "import torchvision as tv\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "from torch.utils.data import TensorDataset\n",
    "from torchvision.models import resnet50\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "from cifar_utils import get_model, get_data, show_img"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0709702",
   "metadata": {},
   "source": [
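    "To train the model on a GPU, the easiest route is to request one with `srun` or `sbatch` and run `python cifar_utils.py` (update the config in that file as appropriate). This notebook only loads the model, runs it on the cached validation split on CPU, and saves the resulting softmax scores."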
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "73ea3e13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model parameters\n",
    "# Model / data parameters shared by all cells below.\n",
    "# NOTE(review): the training-log cell marks the frac_val=0.5 run as 'what we use', but this\n",
    "# config (and the model_filename suffix) uses 0.7 -- confirm which split is intended.\n",
    "config = dict(\n",
    "    num_classes=100,\n",
    "    batch_size=128,\n",
    "    lr=1e-4,\n",
    "    feature_extract=False,  # False: fine-tune all layers; True: fine-tune the last layer only\n",
    "    num_epochs=30,\n",
    "    device='cpu',\n",
    "    frac_val=0.7,  # changed from 0.3\n",
    "    model_filename='best-cifar100-model-fracval=0.7',  # changed from no suffix\n",
    "    num_workers=4,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a7d0910f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model from `config` via cifar_utils.get_model (defined outside this notebook).\n",
    "# NOTE(review): presumably this loads the trained weights cached under config['model_filename'] -- confirm.\n",
    "model = get_model(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6b942258",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training logs for runs at different validation fractions, kept for reference only\n",
    "# (presumably output of `python cifar_utils.py`). The bare string literals are no-ops and the\n",
    "# trailing `None` keeps the cell from displaying anything. The frac_val=0.3 block holds two runs\n",
    "# (1 epoch and 30 epochs).\n",
    "# NOTE(review): the 0.5 run is marked as the one in use, yet the config cell sets frac_val=0.7 -- confirm.\n",
    "'''\n",
    "=== With frac_val = 0.7 === \n",
    "Epoch 29/29\n",
    "----------\n",
    "train Loss: 0.0295 Acc: 0.9938\n",
    "val Loss: 2.5458 Acc: 0.5470\n",
    "\n",
    "Training complete in 4m 43s\n",
    "Best val Acc: 0.546976\n",
    "'''\n",
    "\n",
    "'''\n",
    "=== With frac_val = 0.5 ===  <-- THIS IS WHAT WE USE\n",
    "Epoch 29/29\n",
    "----------\n",
    "train Loss: 0.0315 Acc: 0.9913\n",
    "val Loss: 2.3326 Acc: 0.5974\n",
    "\n",
    "Training complete in 6m 44s\n",
    "Best val Acc: 0.597367\n",
    "'''\n",
    "\n",
    "\n",
    "'''\n",
    "=== With frac_val = 0.3 === \n",
    "Epoch 0/0\n",
    "----------\n",
    "train Loss: 3.3722 Acc: 0.2335\n",
    "val Loss: 2.2128 Acc: 0.4491\n",
    "\n",
    "Training complete in 3m 27s\n",
    "Best val Acc: 0.449056\n",
    "\n",
    "\n",
    "Epoch 29/29\n",
    "----------\n",
    "train Loss: 0.0373 Acc: 0.9889\n",
    "val Loss: 2.4569 Acc: 0.6305\n",
    "\n",
    "Training complete in 64m 27s\n",
    "Best val Acc: 0.630556\n",
    "'''\n",
    "None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "0a976e04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the cached validation split. File names follow\n",
    "# './.cache/<model_filename>-<valdata|vallabels>_frac=<frac_val>.npy'.\n",
    "cache_prefix = './.cache/' + config['model_filename']\n",
    "frac_suffix = f'_frac={config[\"frac_val\"]}.npy'\n",
    "val_imgs = np.load(cache_prefix + '-valdata' + frac_suffix)\n",
    "val_labels = np.load(cache_prefix + '-vallabels' + frac_suffix)\n",
    "# Images and labels must line up one-to-one for any downstream per-sample analysis.\n",
    "assert len(val_imgs) == len(val_labels), (val_imgs.shape, val_labels.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "9d5c6e00",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZhElEQVR4nO3cS69kh3Xd8V3v562q++7u290kRYmMJJuRIyGxnQQG7BhQBkkMj/wN/N0yyMgDw4ATWEAUy6QompTosBk2+8Hb3fdZVbfe7wwEbHjmtQAb9uD/G+/eOH0ete4ZnFXY7Xa7AAAgIor/3AcAAPiXg1AAACRCAQCQCAUAQCIUAACJUAAAJEIBAJAIBQBAKquD3/vgx9bi/XZDnm0bsxER6yjIs/VK1dpda7bl2R/93n+0dq+3G3n26nJo7d4YuyMi3rw8l2dns7G1u9HQz2G717V2t1p1eXbY71u7G019d0TEfD7VZ6dzb/dsJs+Ox/pxRETs7x/Ks72DA2v343e/Jc8eHHm7e92WNf+b//p9fffenrX7bjCQZzsd7x7/6ulLebZcqVi7f/8Hj//BGd4UAACJUAAAJEIBAJAIBQBAIhQAAIlQAAAkQgEAkAgFAEAiFAAAiVAAACRCAQCQ5O6jrdmts1iv5NntcG3t3u307qO7nbU6tqH35Vxd/A9r9/6+3oEyXXsHXm+Y/VEr/ZwvVktrd7WmdwjNR3fW7sV0JM+WzD95dsY9GxGxXuvncDLy+qPm04k8WyyWrN2jgX4Oo+B16+z19f9nf6j3O0VEbML7nTh6dCbPDkfetZ/e6d1k66X3LJ8/+1qeXW68c0L3EQDAQigAABKhAABIhAIAIBEKAIBEKAAAEqEAAEiEAgAgEQoAgEQoAACSXHNR2G2txTvj8+v5zqvQKBpZtjVzb7LQ/5/9/o21u3+rz5tfxke97tVclMpGNYJ5LOWyXo0w1RtLIiKi2ah6/8AwvLm15ktl/eCv31x6B2M8E7Wad+0LBb3SobHXsnZfXV7Isxdv9NmIiFLFu/btdlOeHfW9upXe/oE8+/DxA2v3J3/9iTw7uhtYu//0T378D87wpgAASIQCACARCgCARCgAABKhAABIhAIAIBEKAIBEKAAAEqEAAEiEAgAgEQoAgCR3H9Vr3uJqTS/Mma+8cp3CVu8nqle9vpThaCzPbld6v1NExGQ704dLRjdRRJQKXr7Pp/qxl6vexXeOvLnn9fb0r67k2cFgZO1ezKbWfK2mn/PxaG7tbjTr8mzZvFc2G/35GVx5/USjO71D6OKNfi0jIg5Pjq35j//PR/Ls9avX3rE8uC/Pfv3M6z768rNfybMbo2NOxZsCACARCgCARCgAABKhAABIhAIAIBEKAIBEKAAAEqEAAEiEAgAgEQoAgCTXXBTLXn6sjAqIWsWrotis9c/0d1vzM/CdPl8seuek5NQRlOVLExERnc6eNd/e0+dbLa+KYrPbyLOvnj21dt9eXMuzs7l37RfLpTVfKuv1LN1Oy9rdNJpF2k3vPtys9PnZdGjtHvT78ux0atS+RER96D0To75euTEa6scdETGZ6LtfPvXu8f6Nfo+32l1rt4I3BQBAIhQAAIlQAAAkQgEAkAgFAEAiFAAAiVAAACRCAQCQCAUAQCIUAACJUAAAJLlMpFr3unXWa713pmX29my3K3l2OfP6bNodvZ9ovZhbu6u1ijx7ev+Btfvs4X1r/vrytTz76puX1u6LK727ZWP2DVWLBXl2udI7siIi1puFNV+v6Z1Q7abX79XQb5U4bBvDEVGrNuXZ25HXN/Tqixfy7Gqpd2RFRCxq3rHMp/r1XBm/V7+e138nZjPvvpov9E6ozcp7fhS8KQAAEqEAAEiEAgAgEQoAgEQoAAASoQAASIQCACARCgCARCgAABKhAABI8nfjZ4/fthZPRiN5ttHwKgAKBT3LxtOJtfvo8T15tl3zPtMvLPry7Da86oLB
7YU1//RrvY7g7la/lhER88VUnm2Y1QWLhV4ZsN5Zq6O717LmTw668uxeR6/EiIioVPR7vFb1np9WXb+3CqW2tbtZq8mz0/AqGhZT/b6KiJhP9bqI9ca7WYyfoFjO9FqeiIitUbnh3uMK3hQAAIlQAAAkQgEAkAgFAEAiFAAAiVAAACRCAQCQCAUAQCIUAACJUAAAJEIBAJDk4pnuQcdavJzqfTm9e/et3cO+3iG0fH1u7b7/1vfk2cdHJWv38Bv9nHz02VfW7mL9wJp/+z39/zkc6Oc7IuLq/Gt5trD2+m8mG71v6vTIOyfHBz1r/mR/T57t7tWt3bWq3iHUaDSt3aWKXphTq3gdXIcnPXn2ejK2dr++uLXm7/p38uzozutVmk70XqXRxOwOm+r9Xpu1fp+oeFMAACRCAQCQCAUAQCIUAACJUAAAJEIBAJAIBQBAIhQAAIlQAAAkQgEAkOSai5vLS2vxYjqXZ99++31rd/U3u/Ls1YsX1u7Dh2/Ls+X5K2v3J3/1l/Ls1Dh/ERH3TvRzEhFRNxo6Dh4cWbtrS71eoBR6bUVERLvVkmdPjw+t3bWaV+nQbugVA/tGJUZERKkkP5qxiYK1O4r6xS95TS7RKev/4MCsuXjr5NSa3xT1e2s282ouJkbNxcWlVxNza9Rz3FwPrd0K3hQAAIlQAAAkQgEAkAgFAEAiFAAAiVAAACRCAQCQCAUAQCIUAACJUAAAJEIBAJDkgpX7Z9+xFu//xr48++idd6zdjY6++/T+mbW72enJs1/94sbafXEzkWd3a68T6OrypTUf3aY8+t3vv2etPqzo53yzWlm7W+2GPFsu6v1BERH1ttdP5HQIrUp6T1JERH+k9wItN2trd5T187Ld7azV1Yp+TjpNvccqwuvrioioGP/PTbNj7V519fv2eM/rJZstF/rsYmntVvCmAABIhAIAIBEKAIBEKAAAEqEAAEiEAgAgEQoAgEQoAAASoQAASIQCACARCgCAJJeDvP2vvm0tPjo8lmfvPdRnIyJWi7k+O+pbu0uLrTx7/+zE2v2jP/hP8uznH35k7f7iyZfWfPfbj+TZYf/S2r1e6l085XLB2t1ot+XZalXvd4qI2Jp/I5XKFXl2F16H0Hard19Vq1Vrd7mm9zCtl14H19LoshpOZ9buWtG7V3plvcuqGF6xUqVsdDx1vQ6umOi9V7Wa2Xsl4E0BAJAIBQBAIhQAAIlQAAAkQgEAkAgFAEAiFAAAiVAAACRCAQCQCAUAQJK/v37r/XetxRXjM/3y2quiKKz1z8CPm3olRkREt9eSZwc7rwLg3/zgPXl2eHFu7X75/GtrvmJUNJRqDWv3bKnXF6zWXnXBfKMfd7Xm1VzUjHMSETExahpKO70+JSKi2dDP+Wbj7a4b/8/xyqznMB6JUtn7m7TV0itOIiLKVb1eotPRKzEiIo7u6dU8Y6O2IiLikw8/k2fnM71WRMWbAgAgEQoAgEQoAAASoQAASIQCACARCgCARCgAABKhAABIhAIAIBEKAIBEKAAAklwO0m7UrcXDbz6XZ5dvJtbuotE5tFtOrd3t0lKebZS87qPvPOzJs80//G1rd7dZteYnS/3vgUpr39rda+ldPMu119tT7R7Js8Vm19od4fX8HDT0bqVW23t+yqF3Qjl9XRERvRO9tydK3t+NT5+8kGdnc+/Z/P5337LmC0a3UqftdR/dXF7Js5///Atr93ygd8GVzX4vBW8KAIBEKAAAEqEAAEiEAgAgEQoAgEQoAAASoQAASIQCACARCgCARCgAAJJcc/Hzv/jv1uJv/u5v5dkf/NCrdNht1/Ls65fPrd39N2/k2bN3DqzdB4f35dl/++8+sHYfHnif6V8Yn+nXq15Fw1fPL+XZV/25tXu/05Fnyw19NiKiWfeqQv79f/iePNvteXUEi4Vet1JvNazd9aZ+PatmzcVkqFfWPHkytHZvVt69UinX5NlPPtZ/ryIiPvrJT+XZyeDO2r1/fE8f
3nrVLAreFAAAiVAAACRCAQCQCAUAQCIUAACJUAAAJEIBAJAIBQBAIhQAAIlQAAAkQgEAkOTuo4uLa2vxh//7Z/Jso1yxdh+d6t0g49HM2v2Lbz6SZ79+3rV2/87v/FCeffDA6D+JiA8++LY1Pxoey7ONhtd9dHjQkmf/7H9+bu0+2Nf7bLr1hbX7ybnXxXN5fasfS6dg7V4t9ft2Mu5bu6d3ej/RZKLPRkR8/NGH8uwXX55bu58+fWLND/oDeXbc9/qJKputPHt8tG/tLlX0Dq5//OYj3hQAAH8PoQAASIQCACARCgCARCgAABKhAABIhAIAIBEKAIBEKAAAEqEAAEhyzcW9937DWryNP5NnL169tnbv7ek1CvrH6L9WqumVDv1rrwLgb37+hTx72Gtbu/dPHlrzxYL+98B2u7J2d1t6pUOv4u0urkby7GhjrY6r1y+t+b/9yY08++brE2t3/2Ygz04mc2v31ZVei7FcrK3dL84v5dnL4dja/fy5dyx3I/1eeetMr32JiHh4ptfQLM2qkN1KrzjZVPVKDBVvCgCARCgAABKhAABIhAIAIBEKAIBEKAAAEqEAAEiEAgAgEQoAgEQoAAASoQAASHL30UGzZi3+7d/9LXn2/Msn1u5nn38mz7YP9Y6SiIharSLPVpp6B1NExOVgqc/eDK3dd5OdNb+a6/0qy+XC2n3V1ztnnr1+Y+2eTfW+ofM3U2t3hN7ZFBFR3ujn5XroncPrq2t5dr7Q76uIiJv+QJ6dLbwCqelMv6+2G6/L6KDr9YE9vLcvz9br3u/bfK13drX2D6zd5brev7bdeN1hCt4UAACJUAAAJEIBAJAIBQBAIhQAAIlQAAAkQgEAkAgFAEAiFAAAiVAAACRCAQCQ5O6jblUejYiI3//PP5Zn/5e1OeLnP/mpPHt6/5+u/+bsW+9Zm7/17qk8++qV3n0TETEdP7fmRyO9o2a38/52eHOp9zbd3o6t3Tdv9GNZF/QOmYiIB8eH1vz5hd7xVCzNrd1Op81y7XUI3U31az+feZ1N3T29D+zesddL1u11rPn5TD/ntarXfbRZ651Qi621OmYL/ZxXC97vsoI3BQBAIhQAAIlQAAAkQgEAkAgFAEAiFAAAiVAAACRCAQCQCAUAQCIUAABJ/ka62fQqA/Za+/Lsj//oj6zd25X+ifnHf/2htfvoWP+UvtX0MnU+1SsdJsuetfvZ8741f/nqjTz77jsPrd2VZk+e7Vb0WoSIiOFYry0prHfW7mrDqzq4uriRZwfjibW7Va/Ks5WyXs0SEdEM/fk5Oe1ZuyNK8uRs4lVo9Hre87bf6cmzozu9siQiYr1a6sOVirV7NNV7MU4Ou9ZuBW8KAIBEKAAAEqEAAEiEAgAgEQoAgEQoAAASoQAASIQCACARCgCARCgAABKhAABIcvfRcjG3Ft/O7uTZzdzrhfmvf/zf5NnVTu9iiYhodvTOpsPvvG/tniz1DpTGwZm1uzj+pTU/e613CD177fXC3DvVO4QOju9bu0eLb+TZXtvrnFmb93gsZvJoQ37Sfq0Qem9TcWt2PFX1XqWtXpMUERHLjd4JtNV/fiIi4uLK6/fqdPRerevLa2t3r6t3pB319NmIiMp0Lc/eXt9auxW8KQAAEqEAAEiEAgAgEQoAgEQoAAASoQAASIQCACARCgCARCgAABKhAABI8nfm05H3ifl6oVdXHHaa1u6K3qIQv/WH/8XaPd7oNRd7e/pn9BERlaleLTHfeXn94K23rfnD4yP9WEYDa3enqV+g5dSrOLl//7F+HK2Gtfvy2RNrfrfV6wgenernOyJiOtErN/p3Y2v3aGPUYhS9notaq67PVr17fDT27pWmUXNR77St3eWKXqHSv7qxdq8L+uzobmjtVvCmAABIhAIAIBEKAIBEKAAAEqEAAEiEAgAgEQoAgEQoAAASoQAASIQCACARCgCAJHcfjW+9/o5uV+8zWm+8fpWL64F+HPd+ZO0ubfTulsGb
c2v3dq3/P1u9PWt3u9yx5nt7Z/LsriDfJhERsVov5dlp3+vUKiz1TqDpdGTtni/1LqOIiMJmK89eXNxau6dGT1a5rt+zERGdjn6v3Pa9bp3GTj+W6VS/lhERJaMTKCJicncnz27M36Db2UyeXa68+2r/QO9fa9S841bwpgAASIQCACARCgCARCgAABKhAABIhAIAIBEKAIBEKAAAEqEAAEiEAgAgyf0Fa+OT/oiIvU5Lnj1/dWHtLjX0z8BjsrB2zxf6fLHkZWpxp3+SPjbrHyrlkjW/bunztUbD2l0u6H0ElX2vnmN4odc/rO/Mc1itWvMr45EoVbyqEKcAojhfWbtrNX3+7MGhtdupCpkuvOPuD7zakhffvJJnK+azfHzvWJ49PDyyds+M6/nquVe1o+BNAQCQCAUAQCIUAACJUAAAJEIBAJAIBQBAIhQAAIlQAAAkQgEAkAgFAEAiFAAASS5kWZjdR7OF3oHibY6Yz5by7PLa678p1pry7GQ4tHbPF3qjTa2uH0dERLFYseaHQ71HpjSaWLvLK333uH9l7T53+mym3rU/29f7uiIiWmW946k/Glu7t2u9h+n56xtr98K4D49W3jlx+qDma/03IiLibnJnza+3+t+8xar3/EyM69ntdq3d44l+fUp1owdOxJsCACARCgCARCgAABKhAABIhAIAIBEKAIBEKAAAEqEAAEiEAgAgEQoAgCTXXMRuYy0eGhUQw75eixARsa7oh323mlq7q3N9fnCuVy5ERGyX+u57D8+s3b1T71P66xf/T5795S8+s3ZvVvpn+u2Sd1+N7/R6gf2jPWt3rXNkzc+NZ+L8amDtLhf1vojHj+5ZuydjvbZkNTbrOWp1efbi3Ks4iZL3N+zdSP9defTeW9bu3n5Pnn3+7IW1+0G3Ic+evfPI2q3gTQEAkAgFAEAiFAAAiVAAACRCAQCQCAUAQCIUAACJUAAAJEIBAJAIBQBAIhQAAEkuERqPZtbiVkvv74iC3mUUEbHe6Vk2m3q9Stc3ep9Rs1Swdh90avJsaXpn7S4ul9Z8q9OTZ49PTqzdndJCnp1cv/Z23z+UZ6dbvT8oIuJnH31qzd9N9Gei3elYu8uVijz7/e9+x9r96ad/J8+Wyvo9GxFxeKL3MJVrTWv3rliy5p9/oz/Lt0anVkREuab/vo2n+vMQEfH59YU8e9T3jlvBmwIAIBEKAIBEKAAAEqEAAEiEAgAgEQoAgEQoAAASoQAASIQCACARCgCARCgAAJJcOlSInbX45mooz1Zres9LRERs9C6Rjdl9tJlO5dltzetsGk82+nDd61W6G1xb87XQ9z96cGztXsz1a//C6KeJiJgvBvJsre5dnyh53TqHh/v66oq3++LiRp7987/4K2v3/oF+3P2B13n25Jl+PR8/OrV2b9beb9B6MZdnFwWjqy0ihmP9d+L41Ht+Bn39XtmZ95WCNwUAQCIUAACJUAAAJEIBAJAIBQBAIhQAAIlQAAAkQgEAkAgFAEAiFAAASe4B+NWnv7QWb5d6pcOjdx9bu3clvRZjtdE/6Y+IKG70T+OfffnS2r02PtM/vWdWS8wurflee0+erZW9uoh6Vb8+ja5+HBERk6uVPDtfrK3dVzd9a/7s7L5+LDO9miUi4vpaP5Zmu2nt7g8H8uzSPIfj0VieffHSqH2JiHa7bc3XalV5tu5W7Wz181IM7//5/fff1odLNWu3gjcFAEAiFAAAiVAAACRCAQCQCAUAQCIUAACJUAAAJEIBAJAIBQBAIhQAAIlQAAAkudTm0//71FpcL+g9P7dTvS8lIqJdr8uzk5XXaTJf6x01s9HI2t3d03t+Chu9tyUiYt6fWPObstHH0mxZu+/6d/Ls9c2ttbs/HMqz243XObNceT0/o8lMnp1Mp9buXse4V0re33bLpf7/3BgdPxERlZrek7WJkrV7vvD6o06OjuTZ6cT7DVoZnVCXE71PLSKiWtHP4V7H6yVT8KYAAEiEAgAg
EQoAgEQoAAASoQAASIQCACARCgCARCgAABKhAABIhAIAIMnfSK/MyoCVMb67HVi7F82avju8mou18f8cj7xP449P9M/u3bg+v/DqIkpl/byMR16Fxpu+XkXx4vm5tXs20ysDSmWvAqDdblrzm+1Wnl2vzbqIql4B0Tefn7uxXs/RNCtOCsZ9+/jhvrW7vde25kcj/V4Zmvd4vanfK/W6/nsVEfHlk6/k2aPTY2u3gjcFAEAiFAAAiVAAACRCAQCQCAUAQCIUAACJUAAAJEIBAJAIBQBAIhQAAIlQAAAkuRzm8rZvLS7tqvLsZuv1KnW7HXm23fH6UjZGR0297vXCdLt78mzJKZGJiFqtbs1f9UfybKfrncP1ZiXPVr1qqlhv9E6gfeN8R0T83u/+0Jp/8tVzefanP3tp7d4Z13/pFI1FxG63k2fvxl6/13ffO5NnP3j3kbV7UdJ/UyIifjV4Ks8ul1431fGJ/kzMjb6uiIjpXD+WwwOvP0rBmwIAIBEKAIBEKAAAEqEAAEiEAgAgEQoAgEQoAAASoQAASIQCACARCgCAJNdc1GreJ+btZk+ePTBqKyIiWh3j0+7i1tpdLOufmLdaDWv34PpKnu0bVREREUenh9Z8tarXYixWXgVApaRXUbx1dmztHvaH+nBZvr0jIqJZ8+bbzZo8W6/o5yQi4ma8lGfXRjVLRIRzJG3zHt9r6fUPg4lX//A3n35szbc7+u/Kwwen1u5vGfftR588sXb3B3oFzRdPnlm7FbwpAAASoQAASIQCACARCgCARCgAABKhAABIhAIAIBEKAIBEKAAAEqEAAEiEAgAgFXa73e6f+yAAAP8y8KYAAEiEAgAgEQoAgEQoAAASoQAASIQCACARCgCARCgAABKhAABI/x+4IjFL7HCuIAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Spot-check one validation image. A fixed seed makes the pick reproducible on re-run,\n",
    "# and printing its label helps confirm images and labels are aligned.\n",
    "rng = np.random.default_rng(0)\n",
    "sample_idx = rng.integers(len(val_imgs))\n",
    "print(f'index {sample_idx}, label {val_labels[sample_idx]}')\n",
    "show_img(val_imgs[sample_idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "afa2d72a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 10min 26s, sys: 4min 23s, total: 14min 49s\n",
      "Wall time: 33.1 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# Get softmax scores for every validation image.\n",
    "# eval() switches layers such as BatchNorm/Dropout (e.g. in a ResNet) to inference behaviour.\n",
    "# Left in training mode, BatchNorm would normalise with per-batch statistics and overwrite\n",
    "# its running statistics even under no_grad. Calling it again is harmless if get_model already did.\n",
    "# Batching bounds peak memory instead of pushing the whole split through one forward pass.\n",
    "device = torch.device(config['device'])\n",
    "model = model.to(device).eval()\n",
    "batch_size = config['batch_size']\n",
    "val_tensor = torch.from_numpy(val_imgs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    logits = torch.cat([\n",
    "        model(val_tensor[start:start + batch_size].to(device)).cpu()\n",
    "        for start in range(0, len(val_tensor), batch_size)\n",
    "    ])\n",
    "\n",
    "softmax_scores = torch.nn.functional.softmax(logits, dim=1).numpy()\n",
    "assert softmax_scores.shape[0] == val_imgs.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "59ad2db4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Save softmax scores to ./.cache/best-cifar100-model-fracval=0.7-valsoftmax_frac=0.7.npy\n"
     ]
    }
   ],
   "source": [
    "# Save softmax scores\n",
    "# Persist the scores next to the cached validation data/labels, using the same naming scheme.\n",
    "pth = f\"./.cache/{config['model_filename']}-valsoftmax_frac={config['frac_val']}.npy\"\n",
    "np.save(pth, softmax_scores)\n",
    "print(f'Save softmax scores to {pth}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "114c1685",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
