{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using downloaded and verified file: data/train_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Starting training of model  model_MSD_ERM_51.pt  on SVHN\n",
      "Epoch: 1 \tTraining Loss: 1.861225 \tValidation Loss: 1.737258\n",
      "Epoch: 1 \tTraining accuracy: 0.364570 \tValidation accuracy: 0.415057\n",
      "Validation loss decreased (inf --> 1.737258). \n",
      "Epoch: 2 \tTraining Loss: 1.711976 \tValidation Loss: 1.675906\n",
      "Epoch: 2 \tTraining accuracy: 0.425622 \tValidation accuracy: 0.440448\n",
      "Validation loss decreased (1.737258 --> 1.675906). \n",
      "Epoch: 3 \tTraining Loss: 1.671034 \tValidation Loss: 1.648605\n",
      "Epoch: 3 \tTraining accuracy: 0.440552 \tValidation accuracy: 0.451915\n",
      "Validation loss decreased (1.675906 --> 1.648605). \n",
      "Epoch: 4 \tTraining Loss: 1.645640 \tValidation Loss: 1.636547\n",
      "Epoch: 4 \tTraining accuracy: 0.450193 \tValidation accuracy: 0.455464\n",
      "Validation loss decreased (1.648605 --> 1.636547). \n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_19133/3435468100.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m    402\u001b[0m                 \u001b[0;31m###################\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    403\u001b[0m                 \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 404\u001b[0;31m                 \u001b[0;32mfor\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    405\u001b[0m                     \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    406\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/cvmfs/ai.mila.quebec/apps/arch/distro/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    515\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    516\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 517\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    518\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    519\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/cvmfs/ai.mila.quebec/apps/arch/distro/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    555\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    556\u001b[0m         \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 557\u001b[0;31m         \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    558\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    559\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/cvmfs/ai.mila.quebec/apps/arch/distro/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m     42\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     43\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     45\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/cvmfs/ai.mila.quebec/apps/arch/distro/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     42\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     43\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     45\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/cvmfs/ai.mila.quebec/apps/arch/distro/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/utils/data/dataset.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m    328\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    329\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__getitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 330\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    331\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    332\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__len__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/py3.8/lib/python3.7/site-packages/torchvision/datasets/svhn.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m     98\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     99\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 100\u001b[0;31m             \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    102\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget_transform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/py3.8/lib/python3.7/site-packages/torchvision/transforms/transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m     58\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     59\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 60\u001b[0;31m             \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     61\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/cvmfs/ai.mila.quebec/apps/arch/distro/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    887\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    888\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    890\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    891\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/py3.8/lib/python3.7/site-packages/torchvision/transforms/transforms.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m    624\u001b[0m             \u001b[0mPIL\u001b[0m \u001b[0mImage\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mRandomly\u001b[0m \u001b[0mflipped\u001b[0m \u001b[0mimage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    625\u001b[0m         \"\"\"\n\u001b[0;32m--> 626\u001b[0;31m         \u001b[0;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrand\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    627\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhflip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    628\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "## Training loop for transferring to CIFAR100 or SVHN\n",
    "\n",
    "# Clear the interactive namespace so the cell starts from a clean state.\n",
    "from IPython import get_ipython\n",
    "# InteractiveShell.magic() is deprecated; run_line_magic() is the supported equivalent.\n",
    "get_ipython().run_line_magic('reset', '-sf')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import time\n",
    "timer = 0\n",
    "\n",
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import torch.nn as nn\n",
    "\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "\n",
    "\n",
    "import os, random\n",
    "\n",
    "# import argparse\n",
    "\n",
    "# argument_parser = argparse.ArgumentParser()\n",
    "\n",
    "# argument_parser.add_argument(\"--lr_init\", type=float, help=\"Initial learning rate value, default=0.01. CAREFUL: this will be divided by beta, since the ERM term is multiplied by beta in the objective.\")\n",
    "\n",
    "# parsed_args = argument_parser.parse_args()\n",
    "\n",
    "\n",
    "# Make sure validation splits are the same at all time (e.g. even after loading)\n",
    "seed = 0\n",
    "\n",
    "def seed_init_fn(seed=seed):\n",
    "    \"\"\"Seed the NumPy, Python `random` and PyTorch global random generators.\n",
    "\n",
    "    Args:\n",
    "        seed: value handed to each generator. The default is the value the\n",
    "            module-level `seed` had when this function was defined.\n",
    "    \"\"\"\n",
    "    # Same order as always: NumPy first, then `random`, then PyTorch.\n",
    "    for seed_rng in (np.random.seed, random.seed, torch.manual_seed):\n",
    "        seed_rng(seed)\n",
    "\n",
    "# Run on the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# NOTE(review): not passed to any DataLoader in the code visible here -- confirm it is still needed.\n",
    "num_workers = 0\n",
    "# Batch size shared by the training and validation DataLoaders.\n",
    "batch_size_train_and_valid = 128\n",
    "# Make sure the size of the test set is a multiple of batch_size_test.\n",
    "batch_size_test = 200\n",
    "\n",
    "# Proportion of the full training set that is held out for validation.\n",
    "valid_size = 0.2\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "class BasicBlock(nn.Module):\n",
    "    \"\"\"Residual block made of two 3x3 convolutions plus a shortcut connection.\n",
    "\n",
    "    The shortcut is an empty nn.Sequential unless the stride differs from 1 or the\n",
    "    channel count changes, in which case it is a strided 1x1 convolution followed\n",
    "    by batch normalisation.\n",
    "    \"\"\"\n",
    "\n",
    "    # Ratio between the number of output channels of the block and `planes`.\n",
    "    expansion = 1\n",
    "\n",
    "    def __init__(self, in_planes, planes, stride=1):\n",
    "        \"\"\"Create the block.\n",
    "\n",
    "        Args:\n",
    "            in_planes: number of input channels.\n",
    "            planes: number of channels produced by both convolutions.\n",
    "            stride: stride of the first convolution (and of the shortcut projection).\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        out_planes = self.expansion * planes\n",
    "\n",
    "        # Sub-modules keep their original names and creation order so that\n",
    "        # state_dict keys and random initialisation stay the same.\n",
    "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "\n",
    "        needs_projection = stride != 1 or in_planes != out_planes\n",
    "        if needs_projection:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(out_planes)\n",
    "            )\n",
    "        else:\n",
    "            self.shortcut = nn.Sequential()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x)).\"\"\"\n",
    "        hidden = F.relu(self.bn1(self.conv1(x)))\n",
    "        merged = self.bn2(self.conv2(hidden))\n",
    "        merged += self.shortcut(x)\n",
    "        return F.relu(merged)\n",
    "\n",
    "\n",
    "class Bottleneck(nn.Module):\n",
    "    \"\"\"Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack and a shortcut.\n",
    "\n",
    "    The last convolution widens the output to `expansion * planes` channels. The\n",
    "    shortcut is an empty nn.Sequential unless the stride differs from 1 or the channel\n",
    "    count changes, in which case it is a strided 1x1 convolution followed by batch norm.\n",
    "    \"\"\"\n",
    "\n",
    "    # Ratio between the number of output channels of the block and `planes`.\n",
    "    expansion = 4\n",
    "\n",
    "    def __init__(self, in_planes, planes, stride=1):\n",
    "        \"\"\"Create the block.\n",
    "\n",
    "        Args:\n",
    "            in_planes: number of input channels.\n",
    "            planes: number of channels inside the block; the output has `expansion * planes`.\n",
    "            stride: stride of the 3x3 convolution (and of the shortcut projection).\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        out_planes = self.expansion * planes\n",
    "\n",
    "        # Sub-modules keep their original names and creation order so that\n",
    "        # state_dict keys and random initialisation stay the same.\n",
    "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)\n",
    "        self.bn3 = nn.BatchNorm2d(out_planes)\n",
    "\n",
    "        needs_projection = stride != 1 or in_planes != out_planes\n",
    "        if needs_projection:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(out_planes)\n",
    "            )\n",
    "        else:\n",
    "            self.shortcut = nn.Sequential()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the three conv/batch-norm stages, add the shortcut and apply ReLU.\"\"\"\n",
    "        hidden = F.relu(self.bn1(self.conv1(x)))\n",
    "        hidden = F.relu(self.bn2(self.conv2(hidden)))\n",
    "        merged = self.bn3(self.conv3(hidden))\n",
    "        merged += self.shortcut(x)\n",
    "        return F.relu(merged)\n",
    "\n",
    "\n",
    "class ResNet(nn.Module):\n",
    "    \"\"\"ResNet with a 3x3 stem, four residual stages and a linear classifier.\n",
    "\n",
    "    The stages use 64, 128, 256 and 512 planes; stages 2-4 start with a stride-2 block.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, block, num_blocks, num_classes=10):\n",
    "        \"\"\"Create the network.\n",
    "\n",
    "        Args:\n",
    "            block: block class exposing an `expansion` attribute and accepting\n",
    "                (in_planes, planes, stride).\n",
    "            num_blocks: sequence of four ints, the number of blocks in each stage.\n",
    "            num_classes: number of outputs of the final linear layer.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        # Running input width for the next block; updated by _make_layer.\n",
    "        self.in_planes = 64\n",
    "\n",
    "        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(64)\n",
    "        first, second, third, fourth = num_blocks[0], num_blocks[1], num_blocks[2], num_blocks[3]\n",
    "        self.layer1 = self._make_layer(block, 64, first, stride=1)\n",
    "        self.layer2 = self._make_layer(block, 128, second, stride=2)\n",
    "        self.layer3 = self._make_layer(block, 256, third, stride=2)\n",
    "        self.layer4 = self._make_layer(block, 512, fourth, stride=2)\n",
    "        self.linear = nn.Linear(512 * block.expansion, num_classes)\n",
    "\n",
    "    def _make_layer(self, block, planes, num_blocks, stride):\n",
    "        \"\"\"Build one stage: a block with `stride`, then `num_blocks - 1` stride-1 blocks.\"\"\"\n",
    "        stage = []\n",
    "        for block_stride in [stride] + [1] * (num_blocks - 1):\n",
    "            stage.append(block(self.in_planes, planes, block_stride))\n",
    "            # The next block consumes what this one produced.\n",
    "            self.in_planes = planes * block.expansion\n",
    "        return nn.Sequential(*stage)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the class scores: stem, four stages, 4x4 average pool, flatten, linear.\"\"\"\n",
    "        features = F.relu(self.bn1(self.conv1(x)))\n",
    "        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):\n",
    "            features = stage(features)\n",
    "        pooled = F.avg_pool2d(features, 4)\n",
    "        flattened = pooled.view(pooled.size(0), -1)\n",
    "        return self.linear(flattened)\n",
    "\n",
    "\n",
    "def ResNet18():\n",
    "    \"\"\"Return a ResNet built from BasicBlock with two blocks in each of the four stages.\"\"\"\n",
    "    blocks_per_stage = [2, 2, 2, 2]\n",
    "    return ResNet(BasicBlock, blocks_per_stage)\n",
    "\n",
    "\n",
    "def prepare_model_for_finetuning(model, model_path, model_filename, dataset_name, num_output_classes, device=device):\n",
    "    \"\"\"Load pretrained weights into `model`, freeze them and attach a new trainable head.\n",
    "\n",
    "    Args:\n",
    "        model: network exposing its classifier as the `linear` attribute (e.g. ResNet18()).\n",
    "        model_path: path of the checkpoint file read with torch.load.\n",
    "        model_filename: name of the checkpoint; only used in the log message.\n",
    "        dataset_name: name of the target dataset; only used in the log message.\n",
    "        num_output_classes: number of outputs of the replacement `linear` layer.\n",
    "        device: device the prepared model is moved to.\n",
    "\n",
    "    Returns:\n",
    "        The prepared model, wrapped in torch.nn.DataParallel when `device` is \"cuda\" and\n",
    "        more than one GPU is visible.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no weights could be loaded from the 'current_model' or 'model' key.\n",
    "    \"\"\"\n",
    "    if str(device) == \"cuda\" and torch.cuda.device_count() > 1:\n",
    "        print(\"Using DataParallel\")\n",
    "        model = torch.nn.DataParallel(model)\n",
    "\n",
    "    # Weights are loaded into, and the head replaced on, the bare network, whether or not\n",
    "    # it is wrapped. This removes the need to guess between `model` and `model.module`.\n",
    "    base_model = model.module if isinstance(model, torch.nn.DataParallel) else model\n",
    "\n",
    "    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.\n",
    "    # map_location=\"cpu\" lets checkpoints written on a GPU load on CPU-only machines;\n",
    "    # load_state_dict then copies the values onto whatever device the parameters live on.\n",
    "    checkpoint = torch.load(model_path, map_location=\"cpu\")\n",
    "\n",
    "    loaded_key = None\n",
    "    load_error = None\n",
    "    if isinstance(checkpoint, dict):\n",
    "        # Weights are normally stored under 'current_model'; checkpoints of the PAT model\n",
    "        # use 'model' instead.\n",
    "        for key in (\"current_model\", \"model\"):\n",
    "            if key not in checkpoint:\n",
    "                continue\n",
    "            try:\n",
    "                state_dict = checkpoint[key]\n",
    "                # Checkpoints saved from a DataParallel wrapper prefix every entry with\n",
    "                # 'module.'; strip it so the entries match the bare network.\n",
    "                if len(state_dict) > 0 and all(name.startswith(\"module.\") for name in state_dict):\n",
    "                    state_dict = {name[len(\"module.\"):]: value for name, value in state_dict.items()}\n",
    "                base_model.load_state_dict(state_dict)\n",
    "            except Exception as error:\n",
    "                # Remember the reason and try the next key instead of hiding the failure.\n",
    "                print(\"Failed to load weights stored under key '\" + key + \"':\", error)\n",
    "                load_error = error\n",
    "                continue\n",
    "            loaded_key = key\n",
    "            break\n",
    "\n",
    "    if loaded_key is None:\n",
    "        raise ValueError(\"Did not succeed in loading model -- check the key used to store model in the checkpoint loaded.\") from load_error\n",
    "    print(\"Successfully loaded weights stored under key '\" + loaded_key + \"'.\")\n",
    "    print(\"model_successfully_loaded:\", True, flush=True)\n",
    "\n",
    "    # Freeze every pretrained parameter...\n",
    "    for param in base_model.parameters():\n",
    "        param.requires_grad = False\n",
    "    # ...then swap in a fresh classifier. It is created after the freeze, so it is the only\n",
    "    # trainable part. Its input width is read from the layer it replaces.\n",
    "    base_model.linear = nn.Linear(base_model.linear.in_features, num_output_classes)\n",
    "\n",
    "    model.to(device)\n",
    "    print(\"Loaded model \", model_filename, \" on \" + dataset_name)\n",
    "    return model\n",
    "\n",
    "\n",
    "def prepare_train_and_valid_dataloader(dataset_name, batch_size_train_and_valid=batch_size_train_and_valid, seed=seed):\n",
    "    \"\"\"Build the training and validation DataLoaders for a transfer dataset.\n",
    "\n",
    "    The training split of the dataset is downloaded to the `data` directory if needed and\n",
    "    divided into a training and a validation part. The module-level `valid_size` gives the\n",
    "    validation proportion. The split uses a generator seeded with `seed`, so it is the same\n",
    "    for a given seed.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: \"CIFAR100\" or \"SVHN\".\n",
    "        batch_size_train_and_valid: batch size used by both loaders.\n",
    "        seed: seed of the generators used for the split and for shuffling the training set.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (train_loader, valid_loader, lr_init, num_output_classes).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `dataset_name` is not one of the supported datasets.\n",
    "    \"\"\"\n",
    "    if dataset_name == \"CIFAR100\":\n",
    "        transform_train = transforms.Compose([\n",
    "            transforms.RandomCrop(32, padding=4),\n",
    "            transforms.RandomHorizontalFlip(),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))])\n",
    "\n",
    "        train_and_valid_data = datasets.CIFAR100(root = \"data\", train = True, download = True, transform = transform_train)\n",
    "\n",
    "        lr_init = 0.1\n",
    "        num_output_classes = 100\n",
    "\n",
    "    elif dataset_name == \"SVHN\":\n",
    "        transform_train = transforms.Compose([\n",
    "            transforms.RandomCrop(32, padding=4),\n",
    "            transforms.RandomHorizontalFlip(),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "\n",
    "        train_and_valid_data = datasets.SVHN(root = \"data\", split=\"train\", download = True, transform = transform_train)\n",
    "\n",
    "        lr_init = 0.1\n",
    "        num_output_classes = 10\n",
    "\n",
    "    else:\n",
    "        # One formatted string: extra positional arguments would render as a tuple.\n",
    "        raise ValueError(\"Unsupported dataset name. Supported datasets are CIFAR100, SVHN. You entered: \" + str(dataset_name))\n",
    "\n",
    "    # NOTE(review): the validation subset is split from the dataset built with transform_train,\n",
    "    # so the random crop/flip is also applied to validation images -- confirm this is intended.\n",
    "    num_valid_samples = int(np.floor(valid_size * len(train_and_valid_data)))\n",
    "    num_train_samples = len(train_and_valid_data) - num_valid_samples\n",
    "    train_data, valid_data = torch.utils.data.random_split(\n",
    "        train_and_valid_data,\n",
    "        [num_train_samples, num_valid_samples],\n",
    "        generator=torch.Generator().manual_seed(seed))\n",
    "\n",
    "    # Shuffle the training samples so each epoch sees a different batch composition.\n",
    "    # A dedicated seeded generator keeps the order reproducible for a given seed.\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        train_data,\n",
    "        batch_size = batch_size_train_and_valid,\n",
    "        shuffle = True,\n",
    "        generator = torch.Generator().manual_seed(seed))\n",
    "    # Validation order does not affect the metrics, so it is left unshuffled.\n",
    "    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size = batch_size_train_and_valid)\n",
    "    return train_loader, valid_loader, lr_init, num_output_classes\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Careful: this is about loading pretrained models, not checkpoints of the model during CIFAR100 training !\n",
    "WORKING_DIR = \"results/CIFAR10/\"\n",
    "TRAINED_MODEL_PATH = WORKING_DIR + \"models/\"\n",
    "# Fail early with a clear message instead of a NameError on model_paths further down.\n",
    "if not os.path.isdir(TRAINED_MODEL_PATH):\n",
    "    raise FileNotFoundError(\"Pretrained model directory not found: \" + TRAINED_MODEL_PATH)\n",
    "# Only files directly inside TRAINED_MODEL_PATH are pretrained models. Walking the tree would\n",
    "# overwrite both lists for every sub-directory visited (e.g. .ipynb_checkpoints) and build\n",
    "# paths with the wrong prefix. Sorted so model_num refers to the same file on every run.\n",
    "model_filenames = sorted(\n",
    "    entry for entry in os.listdir(TRAINED_MODEL_PATH)\n",
    "    if os.path.isfile(os.path.join(TRAINED_MODEL_PATH, entry))\n",
    ")\n",
    "model_paths = [TRAINED_MODEL_PATH + filename for filename in model_filenames]\n",
    "\n",
    "\n",
    "# Number of epochs each model is trained for on the target dataset.\n",
    "num_epochs = 30\n",
    "# Controls how many times we repeat finetuning/training of models to get avg and std since it's not too expensive\n",
    "num_training_loops = 3\n",
    "\n",
    "dataset_names = [\"SVHN\", \"CIFAR100\"]\n",
    "\n",
    "for dataset_name in dataset_names:\n",
    "    TRAINING_OUTPUT_ROOT = \"experiments/\" + dataset_name + '/'\n",
    "    for model_num, model_path in enumerate(model_paths):\n",
    "        for training_loop in range(0, num_training_loops):\n",
    "            TRAINING_OUTPUT_PATH = TRAINING_OUTPUT_ROOT + model_filenames[model_num] + '/'\n",
    "            os.makedirs(TRAINING_OUTPUT_PATH, exist_ok=True)\n",
    "\n",
    "            # Make sure validation splits are the same at all time (e.g. even after loading)\n",
    "            seed = training_loop\n",
    "            train_loader, valid_loader, lr_init, num_output_classes = prepare_train_and_valid_dataloader(dataset_name, batch_size_train_and_valid=batch_size_train_and_valid, seed=seed)\n",
    "\n",
    "            model = ResNet18()\n",
    "            model.to(device)\n",
    "            model = prepare_model_for_finetuning(model, model_path, model_filenames[model_num], dataset_name, num_output_classes, device)\n",
    "\n",
    "\n",
    "            writer = SummaryWriter(TRAINING_OUTPUT_PATH, comment=\"_loop_\"+str(training_loop))\n",
    "\n",
    "            optimizer = torch.optim.SGD(model.parameters(), lr = lr_init, momentum=0.9, weight_decay=5e-4)\n",
    "            schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)\n",
    "\n",
    "            valid_loss_min = np.Inf\n",
    "            best_epoch = 0\n",
    "\n",
    "\n",
    "            for epoch in range(0, num_epochs):\n",
    "                train_loss = 0\n",
    "                num_train_correct_preds = 0\n",
    "                valid_loss = 0\n",
    "                num_valid_correct_preds = 0\n",
    "\n",
    "                ###################\n",
    "                # Train the model #\n",
    "                ###################\n",
    "                model.train()\n",
    "                for data, label in train_loader:\n",
    "                    data, label = data.to(device), label.to(device)\n",
    "\n",
    "                    optimizer.zero_grad()\n",
    "                    preds = model(data)\n",
    "                    loss = F.cross_entropy(preds, label)\n",
    "                    loss.backward()\n",
    "                    optimizer.step()\n",
    "                    train_loss += loss.item() * data.size(0)\n",
    "                    num_train_correct_preds += (torch.argmax(preds, dim=1) == label).sum().item()\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "                ######################    \n",
    "                # Validate the model #\n",
    "                ######################\n",
    "                model.eval()\n",
    "                with torch.no_grad():\n",
    "                    for _, (data, label) in enumerate(valid_loader):\n",
    "                        data, label = data.to(device), label.to(device)\n",
    "                        preds = model(data)\n",
    "                        loss = F.cross_entropy(preds, label)\n",
    "                        valid_loss += loss.item() * data.size(0)\n",
    "                        num_valid_correct_preds += (torch.argmax(preds, dim=1) == label).sum().item()\n",
    "\n",
    "\n",
    "                # Average loss over epoch. Careful about computations where we average over batches; they will have a bias if dataset size not multiple of batch_size\n",
    "                train_loss = train_loss / len(train_loader.sampler)\n",
    "                # Handling validation terms differently based on stopping early or not (because when using the full set its size may not be divisible by batch_size) !\n",
    "                valid_loss = valid_loss / len(valid_loader.sampler)\n",
    "                training_acc = num_train_correct_preds / len(train_loader.sampler)\n",
    "                valid_acc = num_valid_correct_preds / len(valid_loader.sampler)\n",
    "\n",
    "                print('Epoch: {} \\tTraining Loss: {:.6f} \\tValidation Loss: {:.6f}'.format(\n",
    "                    epoch+1, \n",
    "                    train_loss,\n",
    "                    valid_loss\n",
    "                    ))\n",
    "\n",
    "                print('Epoch: {} \\tTraining accuracy: {:.6f} \\tValidation accuracy: {:.6f}'.format(\n",
    "                    epoch+1, \n",
    "                    training_acc,\n",
    "                    valid_acc\n",
    "                    ))\n",
    "\n",
    "                epoch_lr = schedule.get_last_lr()[0]\n",
    "                schedule.step()\n",
    "\n",
    "                if valid_loss <= valid_loss_min:\n",
    "                    print('Validation loss decreased ({:.6f} --> {:.6f}). '.format(\n",
    "                    valid_loss_min,\n",
    "                    valid_loss))\n",
    "                    best_epoch = epoch\n",
    "                    valid_loss_min = valid_loss\n",
    "\n",
    "\n",
    "                    path_of_checkpoint = TRAINING_OUTPUT_PATH + dataset_name + '_loop_' + str(training_loop) + '.pt'\n",
    "\n",
    "                    # lr to start from in checkpoint\n",
    "                    lr_init = schedule.get_last_lr()[0]\n",
    "\n",
    "                    checkpoint = {'current_model': model.module.state_dict(),\n",
    "                                  'optimiser': optimizer.state_dict(),\n",
    "                                  'schedule': schedule.state_dict(),\n",
    "                                  'learning_rate': lr_init,\n",
    "                                  'epoch': epoch + 1,\n",
    "                                  'best_epoch': best_epoch,\n",
    "                                  'seed': seed\n",
    "                                 }\n",
    "\n",
    "                    torch.save(checkpoint, path_of_checkpoint)\n",
    "\n",
    "                writer.add_scalar('Learning_rate', epoch_lr, epoch+1)\n",
    "\n",
    "\n",
    "                writer.add_scalar('Training_loss', train_loss, epoch+1)\n",
    "                writer.add_scalar('Validation_loss', valid_loss, epoch+1)\n",
    "\n",
    "                writer.add_scalar('Training_accuracy', training_acc, epoch+1)\n",
    "                writer.add_scalar('Validation_accuracy', valid_acc, epoch+1)\n",
    "\n",
    "            writer.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt/SVHN_loop_2.pt \tTest Loss: 1.705784\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt/SVHN_loop_1.pt \tTest Loss: 1.702524\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt/SVHN_loop_0.pt \tTest Loss: 1.701157\n",
      "Top 1 test accuracy: 0.424311 std 0.001562\n",
      "Top 2 test accuracy: 0.595331 std 0.001715\n",
      "Top 3 test accuracy: 0.711919 std 0.000835\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_MSD_ERM_51.pt/SVHN_loop_0.pt \tTest Loss: 1.553910\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_MSD_ERM_51.pt/SVHN_loop_1.pt \tTest Loss: 1.549035\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_MSD_ERM_51.pt/SVHN_loop_2.pt \tTest Loss: 1.550922\n",
      "Top 1 test accuracy: 0.491114 std 0.001553\n",
      "Top 2 test accuracy: 0.649611 std 0.001166\n",
      "Top 3 test accuracy: 0.746427 std 0.001134\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_MSD_REx_99.pt/SVHN_loop_0.pt \tTest Loss: 1.570131\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_MSD_REx_99.pt/SVHN_loop_1.pt \tTest Loss: 1.559414\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_MSD_REx_99.pt/SVHN_loop_2.pt \tTest Loss: 1.562667\n",
      "Top 1 test accuracy: 0.480306 std 0.001342\n",
      "Top 2 test accuracy: 0.638035 std 0.001577\n",
      "Top 3 test accuracy: 0.740627 std 0.001421\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_PGD_Linf_103.pt/SVHN_loop_0.pt \tTest Loss: 1.556496\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_PGD_Linf_103.pt/SVHN_loop_1.pt \tTest Loss: 1.544496\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_PGD_Linf_103.pt/SVHN_loop_2.pt \tTest Loss: 1.554428\n",
      "Top 1 test accuracy: 0.491677 std 0.003329\n",
      "Top 2 test accuracy: 0.650763 std 0.003663\n",
      "Top 3 test accuracy: 0.749731 std 0.002618\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_std_ERM_140.pt/SVHN_loop_0.pt \tTest Loss: 1.661725\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_std_ERM_140.pt/SVHN_loop_1.pt \tTest Loss: 1.650074\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_std_ERM_140.pt/SVHN_loop_2.pt \tTest Loss: 1.655900\n",
      "Top 1 test accuracy: 0.447846 std 0.001247\n",
      "Top 2 test accuracy: 0.617176 std 0.003194\n",
      "Top 3 test accuracy: 0.726081 std 0.001826\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_std_REx_110.pt/SVHN_loop_0.pt \tTest Loss: 1.587029\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_std_REx_110.pt/SVHN_loop_1.pt \tTest Loss: 1.584581\n",
      "Using downloaded and verified file: data/test_32x32.mat\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_std_REx_110.pt/SVHN_loop_2.pt \tTest Loss: 1.590032\n",
      "Top 1 test accuracy: 0.470690 std 0.000943\n",
      "Top 2 test accuracy: 0.636793 std 0.000596\n",
      "Top 3 test accuracy: 0.739538 std 0.001124\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt/CIFAR100_loop_0.pt \tTest Loss: 2.863767\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt/CIFAR100_loop_1.pt \tTest Loss: 2.865498\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt/CIFAR100_loop_2.pt \tTest Loss: 2.865649\n",
      "Top 1 test accuracy: 0.288733 std 0.001320\n",
      "Top 2 test accuracy: 0.410667 std 0.001617\n",
      "Top 3 test accuracy: 0.490433 std 0.002026\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_MSD_ERM_51.pt/CIFAR100_loop_0.pt \tTest Loss: 2.959761\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_MSD_ERM_51.pt/CIFAR100_loop_1.pt \tTest Loss: 2.967309\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_MSD_ERM_51.pt/CIFAR100_loop_2.pt \tTest Loss: 2.962087\n",
      "Top 1 test accuracy: 0.283767 std 0.001401\n",
      "Top 2 test accuracy: 0.400267 std 0.002201\n",
      "Top 3 test accuracy: 0.472467 std 0.000551\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_MSD_REx_99.pt/CIFAR100_loop_0.pt \tTest Loss: 2.977838\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_MSD_REx_99.pt/CIFAR100_loop_1.pt \tTest Loss: 2.982871\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_MSD_REx_99.pt/CIFAR100_loop_2.pt \tTest Loss: 2.978440\n",
      "Top 1 test accuracy: 0.277100 std 0.001389\n",
      "Top 2 test accuracy: 0.388133 std 0.003758\n",
      "Top 3 test accuracy: 0.465033 std 0.001704\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_PGD_Linf_103.pt/CIFAR100_loop_0.pt \tTest Loss: 2.927328\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_PGD_Linf_103.pt/CIFAR100_loop_1.pt \tTest Loss: 2.933451\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_PGD_Linf_103.pt/CIFAR100_loop_2.pt \tTest Loss: 2.929244\n",
      "Top 1 test accuracy: 0.293267 std 0.001079\n",
      "Top 2 test accuracy: 0.407667 std 0.001305\n",
      "Top 3 test accuracy: 0.482500 std 0.001587\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_std_ERM_140.pt/CIFAR100_loop_0.pt \tTest Loss: 3.121037\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_std_ERM_140.pt/CIFAR100_loop_1.pt \tTest Loss: 3.137564\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_std_ERM_140.pt/CIFAR100_loop_2.pt \tTest Loss: 3.125723\n",
      "Top 1 test accuracy: 0.258667 std 0.002155\n",
      "Top 2 test accuracy: 0.363700 std 0.002030\n",
      "Top 3 test accuracy: 0.431000 std 0.002152\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_std_REx_110.pt/CIFAR100_loop_0.pt \tTest Loss: 3.016966\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_std_REx_110.pt/CIFAR100_loop_1.pt \tTest Loss: 3.020215\n",
      "Files already downloaded and verified\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_std_REx_110.pt/CIFAR100_loop_2.pt \tTest Loss: 3.005564\n",
      "Top 1 test accuracy: 0.277400 std 0.001929\n",
      "Top 2 test accuracy: 0.388567 std 0.000289\n",
      "Top 3 test accuracy: 0.461767 std 0.002676\n"
     ]
    }
   ],
   "source": [
    "## Eval loop for fine tuned models on CIFAR100 or SVHN\n",
    "\n",
    "# Clear the interactive namespace so this cell starts from a clean state.\n",
    "# `InteractiveShell.magic()` is deprecated; `run_line_magic()` is the supported call.\n",
    "from IPython import get_ipython\n",
    "get_ipython().run_line_magic('reset', '-sf')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import time\n",
    "timer = 0\n",
    "\n",
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import torch.nn as nn\n",
    "\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "\n",
    "\n",
    "import os, random\n",
    "\n",
    "# import argparse\n",
    "\n",
    "# argument_parser = argparse.ArgumentParser()\n",
    "\n",
    "# argument_parser.add_argument(\"--lr_init\", type=float, help=\"Initial learning rate value, default=0.01. CAREFUL: this will be divided by beta, since the ERM term is multiplied by beta in the objective.\")\n",
    "\n",
    "# parsed_args = argument_parser.parse_args()\n",
    "\n",
    "\n",
    "# Make sure validation splits are the same at all time (e.g. even after loading)\n",
    "seed = 0\n",
    "\n",
    "def seed_init_fn(seed=seed):\n",
    "    \"\"\"Seed the NumPy, Python `random` and PyTorch generators with `seed`.\"\"\"\n",
    "    for seed_generator in (np.random.seed, random.seed, torch.manual_seed):\n",
    "        seed_generator(seed)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "num_workers = 0\n",
    "# Make sure test_data is a multiple of batch_size_test\n",
    "batch_size_train_and_valid = 128\n",
    "batch_size_test = 200\n",
    "\n",
    "# proportion of full training set used for validation\n",
    "valid_size = 0.2\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "class BasicBlock(nn.Module):\n",
    "    \"\"\"Residual block made of two 3x3 convolutions, each followed by batch norm.\n",
    "\n",
    "    The input is added back through `shortcut` before the final ReLU.\n",
    "    \"\"\"\n",
    "\n",
    "    # Ratio between the number of output channels of the block and `planes`.\n",
    "    expansion = 1\n",
    "\n",
    "    def __init__(self, in_planes, planes, stride=1):\n",
    "        super().__init__()\n",
    "        out_planes = self.expansion * planes\n",
    "\n",
    "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "\n",
    "        # Identity shortcut when shapes already match; otherwise project the input\n",
    "        # with a strided 1x1 convolution + batch norm so it can be added to the output.\n",
    "        if stride == 1 and in_planes == out_planes:\n",
    "            self.shortcut = nn.Sequential()\n",
    "        else:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(out_planes)\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        hidden = F.relu(self.bn1(self.conv1(x)))\n",
    "        summed = self.bn2(self.conv2(hidden))\n",
    "        summed += self.shortcut(x)\n",
    "        return F.relu(summed)\n",
    "\n",
    "\n",
    "class Bottleneck(nn.Module):\n",
    "    \"\"\"Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.\n",
    "\n",
    "    The last convolution outputs `expansion * planes` channels; the input is\n",
    "    added back through `shortcut` before the final ReLU.\n",
    "    \"\"\"\n",
    "\n",
    "    # Ratio between the number of output channels of the block and `planes`.\n",
    "    expansion = 4\n",
    "\n",
    "    def __init__(self, in_planes, planes, stride=1):\n",
    "        super().__init__()\n",
    "        out_planes = self.expansion * planes\n",
    "\n",
    "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        # Only the 3x3 convolution is strided.\n",
    "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)\n",
    "        self.bn3 = nn.BatchNorm2d(out_planes)\n",
    "\n",
    "        # Identity shortcut when shapes already match; otherwise project the input\n",
    "        # with a strided 1x1 convolution + batch norm so it can be added to the output.\n",
    "        if stride == 1 and in_planes == out_planes:\n",
    "            self.shortcut = nn.Sequential()\n",
    "        else:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(out_planes)\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        hidden = F.relu(self.bn1(self.conv1(x)))\n",
    "        hidden = F.relu(self.bn2(self.conv2(hidden)))\n",
    "        summed = self.bn3(self.conv3(hidden))\n",
    "        summed += self.shortcut(x)\n",
    "        return F.relu(summed)\n",
    "\n",
    "\n",
    "class ResNet(nn.Module):\n",
    "    \"\"\"ResNet backbone: a 3x3 stem convolution, four residual stages and a linear classifier.\n",
    "\n",
    "    Args:\n",
    "        block: residual block class exposing an `expansion` attribute and taking\n",
    "            `(in_planes, planes, stride)` as constructor arguments.\n",
    "        num_blocks: sequence with the number of blocks in each of the four stages.\n",
    "        num_classes: size of the output of the final linear layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, block, num_blocks, num_classes=10):\n",
    "        super().__init__()\n",
    "        # Channel count fed to the next block; `_make_layer` updates it as blocks are created.\n",
    "        self.in_planes = 64\n",
    "\n",
    "        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(64)\n",
    "        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n",
    "        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n",
    "        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n",
    "        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n",
    "        self.linear = nn.Linear(512 * block.expansion, num_classes)\n",
    "\n",
    "    def _make_layer(self, block, planes, num_blocks, stride):\n",
    "        \"\"\"Stack `num_blocks` blocks; only the first one uses `stride`, the others use 1.\"\"\"\n",
    "        stage_blocks = []\n",
    "        for block_stride in [stride] + [1] * (num_blocks - 1):\n",
    "            stage_blocks.append(block(self.in_planes, planes, block_stride))\n",
    "            self.in_planes = planes * block.expansion\n",
    "        return nn.Sequential(*stage_blocks)\n",
    "\n",
    "    def forward(self, x):\n",
    "        features = F.relu(self.bn1(self.conv1(x)))\n",
    "        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):\n",
    "            features = stage(features)\n",
    "        pooled = F.avg_pool2d(features, 4)\n",
    "        # Flatten everything but the batch dimension before the classifier.\n",
    "        flattened = pooled.view(pooled.size(0), -1)\n",
    "        return self.linear(flattened)\n",
    "\n",
    "\n",
    "def ResNet18():\n",
    "    \"\"\"Build a ResNet with `BasicBlock` and two blocks in each of the four stages.\"\"\"\n",
    "    blocks_per_stage = [2, 2, 2, 2]\n",
    "    return ResNet(BasicBlock, blocks_per_stage)\n",
    "\n",
    "\n",
    "def model_loading_helper(model, model_path, device=device):\n",
    "    \"\"\"Load the weights stored in the checkpoint at `model_path` into `model`.\n",
    "\n",
    "    The model is wrapped in `torch.nn.DataParallel` when `device` is \"cuda\" and more\n",
    "    than one GPU is visible. The state dict is looked up under the 'current_model'\n",
    "    key first and under the 'model' key as a fallback. For each key, loading is tried\n",
    "    on the (possibly wrapped) model and then on `model.module`.\n",
    "\n",
    "    Args:\n",
    "        model: network whose weights are restored.\n",
    "        model_path: path of the checkpoint file, read with `torch.load`.\n",
    "        device: device used to decide on DataParallel and to map checkpoint tensors to.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(model, is_dataparallel)`: the model (wrapped when DataParallel is used)\n",
    "        and whether it was wrapped.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the weights could not be loaded from either checkpoint key.\n",
    "    \"\"\"\n",
    "    is_dataparallel = False\n",
    "    if str(device) == \"cuda\" and torch.cuda.device_count() > 1:\n",
    "        print(\"Using DataParallel\")\n",
    "        model = torch.nn.DataParallel(model)\n",
    "        is_dataparallel = True\n",
    "\n",
    "    # map_location lets checkpoints saved on a GPU be loaded on a CPU-only machine.\n",
    "    # NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "    checkpoint = torch.load(model_path, map_location=device)\n",
    "\n",
    "    def _load_weights(state_dict):\n",
    "        # Some checkpoints were saved from a DataParallel wrapper and some from the bare\n",
    "        # module, so try the model itself first and fall back to model.module.\n",
    "        try:\n",
    "            model.load_state_dict(state_dict)\n",
    "            print(\"Successfully loaded onto model.\")\n",
    "        except RuntimeError:\n",
    "            print(\"Failed to load model onto model, attempting to load onto model.module...\")\n",
    "            model.module.load_state_dict(state_dict)\n",
    "            print(\"Successfully loaded onto model.module.\")\n",
    "\n",
    "    load_error = None\n",
    "    # 'current_model' is the usual key; 'model' is used by the PAT checkpoints.\n",
    "    for checkpoint_key in (\"current_model\", \"model\"):\n",
    "        try:\n",
    "            _load_weights(checkpoint[checkpoint_key])\n",
    "        except (KeyError, TypeError, RuntimeError, AttributeError) as error:\n",
    "            # Catch only loading failures so that e.g. KeyboardInterrupt still propagates.\n",
    "            print(\"Could not load weights from '\" + checkpoint_key + \"' key:\", repr(error))\n",
    "            load_error = error\n",
    "            continue\n",
    "        print(\"model_successfully_loaded:\", True, flush=True)\n",
    "        return model, is_dataparallel\n",
    "\n",
    "    raise ValueError(\"Did not succeed in loading model -- check the key used to store model in the checkpoint loaded.\") from load_error\n",
    "\n",
    "\n",
    "def prepare_model_for_finetuning(model, model_path, model_filename, dataset_name, num_output_classes, device=device):\n",
    "    \"\"\"Load pretrained weights, freeze them and attach a new linear head.\n",
    "\n",
    "    `model_filename` and `dataset_name` are only used in the log message.\n",
    "    Returns the model moved to `device`.\n",
    "    \"\"\"\n",
    "    model, is_dataparallel = model_loading_helper(model, model_path, device)\n",
    "\n",
    "    # Under DataParallel the actual network is reachable through `.module`.\n",
    "    network = model.module if is_dataparallel else model\n",
    "\n",
    "    # Freeze all loaded parameters ...\n",
    "    for weight in network.parameters():\n",
    "        weight.requires_grad = False\n",
    "    # ... then swap in a new classifier sized for the target dataset.\n",
    "    network.linear = nn.Linear(512 * BasicBlock.expansion, num_output_classes)\n",
    "\n",
    "    model.to(device)\n",
    "    print(\"Loaded model \", model_filename, \" on \" + dataset_name)\n",
    "    return model\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def prepare_test_dataloader(dataset_name, batch_size_test=batch_size_test, seed=seed):\n",
    "    \"\"\"Build the test-set DataLoader for `dataset_name`.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: either \"CIFAR100\" or \"SVHN\".\n",
    "        batch_size_test: batch size of the returned loader.\n",
    "        seed: value passed to `seed_init_fn` before the loader is created.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(test_loader, num_output_classes)`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `dataset_name` is not one of the supported datasets.\n",
    "    \"\"\"\n",
    "    if dataset_name == \"CIFAR100\":\n",
    "        transform_test = transforms.Compose([\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))])\n",
    "\n",
    "        test_data = datasets.CIFAR100(root=\"data\", train=False, download=True, transform=transform_test)\n",
    "        num_output_classes = 100\n",
    "\n",
    "    elif dataset_name == \"SVHN\":\n",
    "        transform_test = transforms.Compose([\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "\n",
    "        test_data = datasets.SVHN(root=\"data\", split=\"test\", download=True, transform=transform_test)\n",
    "        num_output_classes = 10\n",
    "\n",
    "    else:\n",
    "        # Build a single message string so the error reads cleanly instead of as a tuple.\n",
    "        raise ValueError(\"Unsupported dataset name. Supported datasets are CIFAR100, SVHN. You entered: {}\".format(dataset_name))\n",
    "\n",
    "    # Seed the RNGs of this process with a direct call. `worker_init_fn` expects a\n",
    "    # callable taking a worker id, so passing the return value of seed_init_fn(seed)\n",
    "    # (which is None) there would only hide where the seeding actually happens.\n",
    "    seed_init_fn(seed)\n",
    "    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size_test)\n",
    "    return test_loader, num_output_classes\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "    \n",
    "num_epochs = 30\n",
    "# Controls how many times we repeat finetuning/training of models to get avg and std since it's not too expensive\n",
    "num_training_loops = 3\n",
    "# Accuracies are reported for top-1 up to top-`top_k` predictions.\n",
    "top_k = 3\n",
    "\n",
    "dataset_names = [\"SVHN\", \"CIFAR100\"]\n",
    "\n",
    "\n",
    "for dataset_name in dataset_names:\n",
    "    SAVE_DIR = \"results/\" + dataset_name + '/'\n",
    "    TRAINED_MODEL_PATH = \"experiments/\" + dataset_name + '/'\n",
    "    _, base_model_list, _ = next(os.walk(TRAINED_MODEL_PATH))\n",
    "    # os.walk lists entries in arbitrary order; sort so the evaluation order is reproducible.\n",
    "    for base_model in sorted(base_model_list):\n",
    "        topk_accuracies_test = []\n",
    "\n",
    "        _, _, base_model_iterates = next(os.walk(TRAINED_MODEL_PATH+base_model))\n",
    "        base_model_iterates = sorted(\n",
    "            base_model_iterate for base_model_iterate in base_model_iterates if base_model_iterate.endswith(\".pt\")\n",
    "        )\n",
    "        if not base_model_iterates:\n",
    "            # Nothing to aggregate: torch.stack on an empty list below would raise.\n",
    "            print(\"No .pt checkpoint found for\", base_model, \"-- skipping.\", flush=True)\n",
    "            continue\n",
    "\n",
    "        # Iterate over all repeats with different seeds of a given model to aggregate statistics.\n",
    "        for eval_loop, base_model_iterate in enumerate(base_model_iterates):\n",
    "            model_path = TRAINED_MODEL_PATH + base_model + '/' + base_model_iterate\n",
    "\n",
    "            seed = eval_loop\n",
    "            test_loader, num_output_classes = prepare_test_dataloader(dataset_name, batch_size_test=batch_size_test, seed=seed)\n",
    "\n",
    "            model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_output_classes)\n",
    "            model.to(device)\n",
    "            model, _ = model_loading_helper(model, model_path, device)\n",
    "\n",
    "            test_loss = 0\n",
    "            # Entry k accumulates the number of samples whose label is among the top-(k+1) predictions.\n",
    "            topk_accuracies_test.append(torch.zeros(top_k))\n",
    "\n",
    "            ##################\n",
    "            # Test the model #\n",
    "            ##################\n",
    "            model.eval()\n",
    "            with torch.no_grad():\n",
    "                for data, label in test_loader:\n",
    "                    data, label = data.to(device), label.to(device)\n",
    "\n",
    "                    preds = model(data)\n",
    "                    loss = F.cross_entropy(preds, label)\n",
    "                    # Weight by the batch size so a smaller last batch does not bias the average.\n",
    "                    test_loss += loss.item() * data.size(0)\n",
    "\n",
    "                    predicted_topk = torch.topk(preds, top_k, dim=1).indices\n",
    "                    # hits[i, j] is True when the (j+1)-th ranked class of sample i is its label.\n",
    "                    hits = predicted_topk.eq(label.view(-1, 1))\n",
    "                    # A sample counts for top-(k+1) when a hit occurs in any of the first k+1 columns.\n",
    "                    hits_within_k = hits.long().cumsum(dim=1) > 0\n",
    "                    topk_accuracies_test[eval_loop] += hits_within_k.sum(dim=0).float().cpu()\n",
    "\n",
    "            test_loss = test_loss / len(test_loader.sampler)\n",
    "            topk_accuracies_test[eval_loop] /= len(test_loader.sampler)\n",
    "\n",
    "            print(\"Model: {} \\tTest Loss: {:.6f}\".format(\n",
    "                base_model + '/' + base_model_iterate,\n",
    "                test_loss\n",
    "                ))\n",
    "\n",
    "        # Compute statistics over the runs\n",
    "        topk_accuracies_test_std, topk_accuracies_test_mean = torch.std_mean(torch.stack(topk_accuracies_test), dim=0)\n",
    "\n",
    "        for iter_topk in range(0, top_k):\n",
    "            print(\"Top {} test accuracy: {:.6f} std {:.6f}\".format(\n",
    "                iter_topk+1,\n",
    "                topk_accuracies_test_mean[iter_topk], topk_accuracies_test_std[iter_topk]\n",
    "                ), flush=True)\n",
    "\n",
    "        # Convert to NumPy before saving.\n",
    "        topk_accuracies_test = [run_accuracies.detach().numpy() for run_accuracies in topk_accuracies_test]\n",
    "        results = {}\n",
    "        results[\"topk_accuracies\"] = topk_accuracies_test\n",
    "        results[\"topk_accuracies_mean\"] = topk_accuracies_test_mean.detach().numpy()\n",
    "        results[\"topk_accuracies_std\"] = topk_accuracies_test_std.detach().numpy()\n",
    "        results[\"base_model\"] = base_model\n",
    "        results[\"number_of_iterations\"] = len(base_model_iterates)\n",
    "\n",
    "        working_dir_of_save = SAVE_DIR + \"test_accs/\"\n",
    "        os.makedirs(working_dir_of_save, exist_ok=True)\n",
    "        np.save(working_dir_of_save + base_model, results)\n",
    "        # fig.savefig(working_dir_of_save + model_filenames[model_num] + \"_confusion_matrix.pdf\", bbox_inches='tight')\n",
    "        # plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 2.403580\n",
      "Top 1 test accuracy on corruption brightness: 0.31656\n",
      "Top 2 test accuracy on corruption brightness: 0.4671\n",
      "Top 3 test accuracy on corruption brightness: 0.58354\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 2.119864\n",
      "Top 1 test accuracy on corruption contrast: 0.35358\n",
      "Top 2 test accuracy on corruption contrast: 0.47242\n",
      "Top 3 test accuracy on corruption contrast: 0.56892\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 2.551339\n",
      "Top 1 test accuracy on corruption defocus_blur: 0.22878\n",
      "Top 2 test accuracy on corruption defocus_blur: 0.3711\n",
      "Top 3 test accuracy on corruption defocus_blur: 0.48942\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 2.603515\n",
      "Top 1 test accuracy on corruption elastic_transform: 0.2158\n",
      "Top 2 test accuracy on corruption elastic_transform: 0.36232\n",
      "Top 3 test accuracy on corruption elastic_transform: 0.48408\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 2.156470\n",
      "Top 1 test accuracy on corruption fog: 0.32502\n",
      "Top 2 test accuracy on corruption fog: 0.47612\n",
      "Top 3 test accuracy on corruption fog: 0.59332\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 2.667008\n",
      "Top 1 test accuracy on corruption frost: 0.2841\n",
      "Top 2 test accuracy on corruption frost: 0.44744\n",
      "Top 3 test accuracy on corruption frost: 0.58938\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 2.604510\n",
      "Top 1 test accuracy on corruption gaussian_blur: 0.22176\n",
      "Top 2 test accuracy on corruption gaussian_blur: 0.35922\n",
      "Top 3 test accuracy on corruption gaussian_blur: 0.47902\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 3.133063\n",
      "Top 1 test accuracy on corruption gaussian_noise: 0.13402\n",
      "Top 2 test accuracy on corruption gaussian_noise: 0.2682\n",
      "Top 3 test accuracy on corruption gaussian_noise: 0.39692\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 2.783115\n",
      "Top 1 test accuracy on corruption glass_blur: 0.17968\n",
      "Top 2 test accuracy on corruption glass_blur: 0.32156\n",
      "Top 3 test accuracy on corruption glass_blur: 0.4421\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 3.149399\n",
      "Top 1 test accuracy on corruption impulse_noise: 0.14694\n",
      "Top 2 test accuracy on corruption impulse_noise: 0.2774\n",
      "Top 3 test accuracy on corruption impulse_noise: 0.3893\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 2.684133\n",
      "Top 1 test accuracy on corruption jpeg_compression: 0.18642\n",
      "Top 2 test accuracy on corruption jpeg_compression: 0.34114\n",
      "Top 3 test accuracy on corruption jpeg_compression: 0.46632\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 2.533250\n",
      "Top 1 test accuracy on corruption motion_blur: 0.23226\n",
      "Top 2 test accuracy on corruption motion_blur: 0.3755\n",
      "Top 3 test accuracy on corruption motion_blur: 0.49254\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 2.565742\n",
      "Top 1 test accuracy on corruption pixelate: 0.20924\n",
      "Top 2 test accuracy on corruption pixelate: 0.35608\n",
      "Top 3 test accuracy on corruption pixelate: 0.47678\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 2.689274\n",
      "Top 1 test accuracy on corruption saturate: 0.20754\n",
      "Top 2 test accuracy on corruption saturate: 0.35158\n",
      "Top 3 test accuracy on corruption saturate: 0.46622\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 3.105118\n",
      "Top 1 test accuracy on corruption shot_noise: 0.13638\n",
      "Top 2 test accuracy on corruption shot_noise: 0.27388\n",
      "Top 3 test accuracy on corruption shot_noise: 0.40004\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 2.783720\n",
      "Top 1 test accuracy on corruption snow: 0.22296\n",
      "Top 2 test accuracy on corruption snow: 0.37836\n",
      "Top 3 test accuracy on corruption snow: 0.51164\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 2.835473\n",
      "Top 1 test accuracy on corruption spatter: 0.18384\n",
      "Top 2 test accuracy on corruption spatter: 0.32502\n",
      "Top 3 test accuracy on corruption spatter: 0.44634\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 3.146637\n",
      "Top 1 test accuracy on corruption speckle_noise: 0.13224\n",
      "Top 2 test accuracy on corruption speckle_noise: 0.26622\n",
      "Top 3 test accuracy on corruption speckle_noise: 0.3916\n",
      "Using DataParallel\n",
      "Failed to load model onto model, attempting to load onto model.module...\n",
      "Successfully loaded onto model.module.\n",
      "model_successfully_loaded: True\n",
      "Model: model_ERM_clean_56.pt \tTest Loss: 2.597282\n",
      "Top 1 test accuracy on corruption zoom_blur: 0.22318\n",
      "Top 2 test accuracy on corruption zoom_blur: 0.36402\n",
      "Top 3 test accuracy on corruption zoom_blur: 0.48698\n"
     ]
    }
   ],
   "source": [
    "### Eval base models on CIFAR-10-C\n",
    "\n",
    "\n",
    "# Clear the interactive namespace so this cell starts from a clean state.\n",
    "# run_line_magic is the supported replacement for the deprecated InteractiveShell.magic().\n",
    "from IPython import get_ipython\n",
    "get_ipython().run_line_magic('reset', '-sf')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import time\n",
    "timer = 0\n",
    "\n",
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import torch.nn as nn\n",
    "\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "import os, random\n",
    "\n",
    "# import argparse\n",
    "\n",
    "# argument_parser = argparse.ArgumentParser()\n",
    "\n",
    "# argument_parser.add_argument(\"--lr_init\", type=float, help=\"Initial learning rate value, default=0.01. CAREFUL: this will be divided by beta, since the ERM term is multiplied by beta in the objective.\")\n",
    "\n",
    "# parsed_args = argument_parser.parse_args()\n",
    "\n",
    "\n",
    "# Fixed seed so that anything drawn from the global RNGs is the same on every run\n",
    "# (e.g. even after loading)\n",
    "seed = 0\n",
    "\n",
    "\n",
    "def seed_init_fn(seed=seed):\n",
    "    \"\"\"Seed the NumPy, Python `random` and PyTorch global RNGs with `seed`; returns None.\"\"\"\n",
    "    for seed_rng in (np.random.seed, random.seed, torch.manual_seed):\n",
    "        seed_rng(seed)\n",
    "\n",
    "\n",
    "# Run on the GPU when CUDA is available, otherwise on the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "# Evaluation settings\n",
    "batch_size_test = 200\n",
    "num_workers = 0\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "class CIFAR10C(datasets.VisionDataset):\n",
    "    \"\"\"CIFAR-10-C images of a single corruption type, loaded from `.npy` files.\n",
    "\n",
    "    Reads `<root>/<name>.npy` as the images and `<root>/labels.npy` as the targets.\n",
    "    When `root` does not exist yet, the CIFAR-10-C archive is first downloaded from\n",
    "    Zenodo and extracted into `root`.\n",
    "\n",
    "    Args:\n",
    "        root: Directory that holds (or will hold) the `.npy` files.\n",
    "        name: Corruption to load; must be one of `CIFAR10C.corruptions`.\n",
    "        transform: Optional callable applied to each PIL image.\n",
    "        target_transform: Optional callable applied to each label tensor.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if `name` is not a known corruption.\n",
    "    \"\"\"\n",
    "\n",
    "    corruptions = ['brightness', 'contrast', 'defocus_blur', 'elastic_transform',\n",
    "                        'fog', 'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur',\n",
    "                        'impulse_noise', 'jpeg_compression', 'motion_blur', 'pixelate',\n",
    "                        'saturate', 'shot_noise', 'snow', 'spatter', 'speckle_noise',\n",
    "                        'zoom_blur']\n",
    "\n",
    "    def __init__(self, root :str, name :str, transform=None, target_transform=None):\n",
    "        assert name in self.corruptions\n",
    "\n",
    "        # Download the dataset if needed\n",
    "        if not os.path.exists(root):\n",
    "            import urllib.request\n",
    "            from tqdm import tqdm\n",
    "            import tarfile\n",
    "            # `root` can be nested (e.g. \"data/CIFAR-10-C/\"): create missing parent\n",
    "            # directories as well, which os.mkdir would refuse to do.\n",
    "            os.makedirs(root, exist_ok=True)\n",
    "            url = \"https://zenodo.org/record/2535967/files/CIFAR-10-C.tar?download=1\"\n",
    "            file_name = 'cifar-c.tar'\n",
    "            file_path = os.path.join(root, file_name)\n",
    "\n",
    "            # Check if .tar of CIFAR-C already downloaded, and download it if necessary\n",
    "            if not os.path.exists(file_path):\n",
    "                print('Downloading CIFAR-C dataset...')\n",
    "                with tqdm(unit='B', unit_scale=True, desc=file_name, leave=True) as progress_bar:\n",
    "                    # urlretrieve reports the running block count; advance the bar by the difference\n",
    "                    # between the bytes seen so far and what the bar already shows.\n",
    "                    urllib.request.urlretrieve(url, file_path, reporthook=lambda blocknum, blocksize, totalsize: progress_bar.update(blocknum * blocksize - progress_bar.n))\n",
    "            else:\n",
    "                print(file_name, \" already downloaded.\")\n",
    "\n",
    "            # Extract the file\n",
    "            # NOTE(review): the archive is extracted as-is. If it wraps the .npy files in a\n",
    "            # top-level folder they will not sit directly under `root`, where they are read\n",
    "            # below -- confirm the archive layout.\n",
    "            # NOTE(security): extractall trusts the member paths of the downloaded archive.\n",
    "            with tarfile.open(file_path, 'r') as tar:\n",
    "                tar.extractall(root)\n",
    "\n",
    "        super(CIFAR10C, self).__init__(root, transform=transform, target_transform=target_transform)\n",
    "        data_path = os.path.join(root, name + '.npy')\n",
    "        target_path = os.path.join(root, 'labels.npy')\n",
    "\n",
    "        self.data = np.load(data_path)\n",
    "        self.targets = np.load(target_path)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return `(image, target)` for `index`.\n",
    "\n",
    "        The stored array entry is converted to a PIL image and the label to a\n",
    "        `torch.long` tensor before the optional transforms are applied.\n",
    "        \"\"\"\n",
    "        img, targets = self.data[index], self.targets[index]\n",
    "        img = Image.fromarray(img)\n",
    "        targets = torch.tensor(targets, dtype=torch.long)\n",
    "\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "        if self.target_transform is not None:\n",
    "            targets = self.target_transform(targets)\n",
    "\n",
    "        return img, targets\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of entries in the loaded image array.\"\"\"\n",
    "        return len(self.data)\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "class BasicBlock(nn.Module):\n",
    "    \"\"\"Residual block: two 3x3 conv + batch-norm layers added to a shortcut branch.\"\"\"\n",
    "\n",
    "    expansion = 1\n",
    "\n",
    "    def __init__(self, in_planes, planes, stride=1):\n",
    "        super().__init__()\n",
    "        out_planes = self.expansion * planes\n",
    "\n",
    "        # Main branch: only the first convolution applies `stride`.\n",
    "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "\n",
    "        # Shortcut branch: identity when the shapes already agree, otherwise a\n",
    "        # strided 1x1 convolution followed by batch norm.\n",
    "        if stride == 1 and in_planes == out_planes:\n",
    "            self.shortcut = nn.Sequential()\n",
    "        else:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(out_planes)\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        hidden = self.conv1(x)\n",
    "        hidden = F.relu(self.bn1(hidden))\n",
    "        hidden = self.conv2(hidden)\n",
    "        hidden = self.bn2(hidden)\n",
    "        hidden += self.shortcut(x)\n",
    "        return F.relu(hidden)\n",
    "\n",
    "\n",
    "class Bottleneck(nn.Module):\n",
    "    \"\"\"Residual block: 1x1, 3x3 and 1x1 conv + batch-norm layers added to a shortcut branch.\"\"\"\n",
    "\n",
    "    expansion = 4\n",
    "\n",
    "    def __init__(self, in_planes, planes, stride=1):\n",
    "        super().__init__()\n",
    "        out_planes = self.expansion * planes\n",
    "\n",
    "        # Main branch: the middle 3x3 convolution is the one that applies `stride`.\n",
    "        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)\n",
    "        self.bn3 = nn.BatchNorm2d(out_planes)\n",
    "\n",
    "        # Shortcut branch: identity when the shapes already agree, otherwise a\n",
    "        # strided 1x1 convolution followed by batch norm.\n",
    "        if stride == 1 and in_planes == out_planes:\n",
    "            self.shortcut = nn.Sequential()\n",
    "        else:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(out_planes)\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        hidden = F.relu(self.bn1(self.conv1(x)))\n",
    "        hidden = F.relu(self.bn2(self.conv2(hidden)))\n",
    "        hidden = self.conv3(hidden)\n",
    "        hidden = self.bn3(hidden)\n",
    "        hidden += self.shortcut(x)\n",
    "        return F.relu(hidden)\n",
    "\n",
    "\n",
    "class ResNet(nn.Module):\n",
    "    \"\"\"ResNet: 3x3 conv stem, four residual stages, 4x4 average pooling and a linear classifier.\n",
    "\n",
    "    Args:\n",
    "        block: Residual block class; must expose an `expansion` attribute and accept\n",
    "            `(in_planes, planes, stride)`.\n",
    "        num_blocks: Number of blocks for each of the four stages.\n",
    "        num_classes: Output size of the final linear layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, block, num_blocks, num_classes=10):\n",
    "        super().__init__()\n",
    "        # Running channel count; _make_layer advances it after every block it creates.\n",
    "        self.in_planes = 64\n",
    "\n",
    "        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(64)\n",
    "\n",
    "        # (planes, stride) of stages layer1 .. layer4, registered in that order.\n",
    "        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))\n",
    "        for stage_idx, (planes, stride) in enumerate(stage_specs):\n",
    "            stage = self._make_layer(block, planes, num_blocks[stage_idx], stride=stride)\n",
    "            setattr(self, 'layer{}'.format(stage_idx + 1), stage)\n",
    "\n",
    "        self.linear = nn.Linear(512 * block.expansion, num_classes)\n",
    "\n",
    "    def _make_layer(self, block, planes, num_blocks, stride):\n",
    "        \"\"\"Stack blocks into one stage; only the first block uses `stride`, the others use 1.\"\"\"\n",
    "        stage_blocks = []\n",
    "        for block_stride in [stride, *([1] * (num_blocks - 1))]:\n",
    "            stage_blocks.append(block(self.in_planes, planes, block_stride))\n",
    "            self.in_planes = planes * block.expansion\n",
    "        return nn.Sequential(*stage_blocks)\n",
    "\n",
    "    def forward(self, x):\n",
    "        features = F.relu(self.bn1(self.conv1(x)))\n",
    "        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):\n",
    "            features = stage(features)\n",
    "        pooled = F.avg_pool2d(features, 4)\n",
    "        # Flatten everything but the batch dimension before the classifier.\n",
    "        flat = pooled.view(pooled.size(0), -1)\n",
    "        return self.linear(flat)\n",
    "\n",
    "\n",
    "def ResNet18():\n",
    "    \"\"\"Build a ResNet made of BasicBlock units, two per stage, with the default number of classes.\"\"\"\n",
    "    blocks_per_stage = [2, 2, 2, 2]\n",
    "    return ResNet(BasicBlock, blocks_per_stage)\n",
    "\n",
    "\n",
    "def _strip_dataparallel_prefix(state_dict):\n",
    "    \"\"\"Return `state_dict` without the 'module.' key prefix, removed for as long as every key carries it.\"\"\"\n",
    "    prefix = 'module.'\n",
    "    while len(state_dict) > 0 and all(key.startswith(prefix) for key in state_dict):\n",
    "        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}\n",
    "    return state_dict\n",
    "\n",
    "\n",
    "def model_loading_helper(model, model_path, device=device):\n",
    "    \"\"\"Load checkpoint weights into `model`, wrapping it in DataParallel when several GPUs are visible.\n",
    "\n",
    "    The checkpoint is expected to store the state dict under the 'current_model' key\n",
    "    or, failing that, under the 'model' key. State dicts saved from a DataParallel\n",
    "    wrapper (keys prefixed with 'module.') and from a bare model are both accepted,\n",
    "    whether or not the model ends up wrapped here.\n",
    "\n",
    "    Args:\n",
    "        model: Module whose architecture matches the checkpoint.\n",
    "        model_path: Path of the checkpoint file handed to torch.load.\n",
    "        device: Device the checkpoint tensors are mapped to. DataParallel is used only\n",
    "            when str(device) == \"cuda\" and more than one GPU is available.\n",
    "\n",
    "    Returns:\n",
    "        (model, is_dataparallel): the possibly wrapped model and whether it was wrapped.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the weights could not be loaded from either key.\n",
    "    \"\"\"\n",
    "    is_dataparallel = False\n",
    "    if str(device) == \"cuda\" and torch.cuda.device_count() > 1:\n",
    "        print(\"Using DataParallel\")\n",
    "        model = torch.nn.DataParallel(model)\n",
    "        is_dataparallel = True\n",
    "\n",
    "    # Read the file once; map_location keeps checkpoints written on a GPU loadable\n",
    "    # on a CPU-only machine.\n",
    "    checkpoint = torch.load(model_path, map_location=device)\n",
    "\n",
    "    # Weights always go into the bare module, so one code path serves both the\n",
    "    # wrapped and the unwrapped case.\n",
    "    bare_model = model.module if is_dataparallel else model\n",
    "\n",
    "    last_error = None\n",
    "    # 'current_model' is tried first; 'model' is the key used by the PAT checkpoints.\n",
    "    for key in ('current_model', 'model'):\n",
    "        try:\n",
    "            bare_model.load_state_dict(_strip_dataparallel_prefix(checkpoint[key]))\n",
    "        except Exception as error:  # missing key, or weights that do not fit the architecture\n",
    "            print(\"Could not load weights from checkpoint key '{}': {}\".format(key, error))\n",
    "            last_error = error\n",
    "            continue\n",
    "        model_successfully_loaded = True\n",
    "        print(\"model_successfully_loaded:\", model_successfully_loaded, flush=True)\n",
    "        return model, is_dataparallel\n",
    "\n",
    "    raise ValueError(\"Did not succeed in loading model -- check the key used to store model in the checkpoint loaded.\") from last_error\n",
    "\n",
    "\n",
    "\n",
    "def prepare_test_dataloader(dataset_name, batch_size_test=batch_size_test, seed=seed, corruption_name=None):\n",
    "    \"\"\"Build the evaluation DataLoader for one of the supported test sets.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: \"CIFAR100\", \"SVHN\" or \"CIFAR-10-C\".\n",
    "        batch_size_test: Batch size of the returned loader.\n",
    "        seed: Seed handed to seed_init_fn right before the loader is created.\n",
    "        corruption_name: Required for \"CIFAR-10-C\"; one of CIFAR10C.corruptions.\n",
    "\n",
    "    Returns:\n",
    "        (test_loader, num_output_classes)\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `dataset_name` is not supported.\n",
    "        AssertionError: if `corruption_name` is not a known corruption (CIFAR-10-C only).\n",
    "    \"\"\"\n",
    "    if dataset_name == \"CIFAR100\":\n",
    "        transform_test = transforms.Compose([\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))])\n",
    "\n",
    "        test_data = datasets.CIFAR100(root = \"data\", train = False, download = True, transform = transform_test)\n",
    "        num_output_classes = 100\n",
    "\n",
    "    elif dataset_name == \"SVHN\":\n",
    "        transform_test = transforms.Compose([\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "\n",
    "        test_data = datasets.SVHN(root = \"data\", split=\"test\", download = True, transform = transform_test)\n",
    "        num_output_classes = 10\n",
    "\n",
    "    elif dataset_name == \"CIFAR-10-C\":\n",
    "        assert corruption_name in CIFAR10C.corruptions\n",
    "        transform_test = transforms.Compose([\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize((0.49139968,  0.48215841,  0.44653091), (0.24703223,  0.24348513,  0.26158784))])\n",
    "\n",
    "        test_data = CIFAR10C(root=\"data/CIFAR-10-C/\", name=corruption_name, transform=transform_test)\n",
    "        num_output_classes = 10\n",
    "\n",
    "    else:\n",
    "        raise ValueError(\"Unsupported dataset name. Supported datasets are CIFAR100, SVHN, CIFAR-10-C. You entered: {}\".format(dataset_name))\n",
    "\n",
    "    # Seed the global RNGs once, right before the loader is built. DataLoader's\n",
    "    # worker_init_fn expects a callable that receives a worker id; passing the result\n",
    "    # of seed_init_fn(seed) would hand it None, so the seeding is done explicitly and\n",
    "    # no worker_init_fn is set (the loader uses no worker processes).\n",
    "    seed_init_fn(seed)\n",
    "    test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size_test)\n",
    "    return test_loader, num_output_classes\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "WORKING_DIR = \"results/CIFAR10/\"\n",
    "TRAINED_MODEL_PATH = WORKING_DIR + \"models/\"\n",
    "\n",
    "# Evaluate the checkpoint files stored directly in TRAINED_MODEL_PATH. Walking the tree\n",
    "# with os.walk would also visit sub-directories and let the last one visited overwrite\n",
    "# these lists. Sorted so that the evaluation order is deterministic.\n",
    "model_filenames = sorted(\n",
    "    entry for entry in os.listdir(TRAINED_MODEL_PATH)\n",
    "    if os.path.isfile(os.path.join(TRAINED_MODEL_PATH, entry))\n",
    ")\n",
    "model_paths = [TRAINED_MODEL_PATH + filename for filename in model_filenames]\n",
    "\n",
    "top_k = 3\n",
    "\n",
    "dataset_name = \"CIFAR-10-C\"\n",
    "SAVE_DIR = \"results/\" + dataset_name + '/'\n",
    "working_dir_of_save = SAVE_DIR + \"test_accs/\"\n",
    "os.makedirs(working_dir_of_save, exist_ok=True)\n",
    "\n",
    "for model_num, model_path in enumerate(model_paths):\n",
    "    # The weights do not depend on the corruption, so the model is built and the\n",
    "    # checkpoint loaded once per file rather than once per corruption.\n",
    "    model = ResNet18()\n",
    "    model.to(device)\n",
    "    model, _ = model_loading_helper(model, model_path, device)\n",
    "    model.eval()\n",
    "\n",
    "    # corruption name -> numpy array holding the top-1 .. top-k accuracies\n",
    "    topk_accuracies_test = {}\n",
    "    for corruption_name in CIFAR10C.corruptions:\n",
    "        seed = 0\n",
    "        test_loader, num_output_classes = prepare_test_dataloader(dataset_name, batch_size_test=batch_size_test, corruption_name=corruption_name, seed=seed)\n",
    "\n",
    "        test_loss = 0\n",
    "        topk_hits = torch.zeros(top_k)\n",
    "\n",
    "        ######################\n",
    "        # Test the model #\n",
    "        ######################\n",
    "        with torch.no_grad():\n",
    "            for data, label in test_loader:\n",
    "                data, label = data.to(device), label.to(device)\n",
    "\n",
    "                preds = model(data)\n",
    "                loss = F.cross_entropy(preds, label)\n",
    "                # cross_entropy averages over the batch: weight by the batch size so the\n",
    "                # division after the loop yields a per-sample mean.\n",
    "                test_loss += loss.item() * data.size(0)\n",
    "\n",
    "                predicted_topk = torch.topk(preds, top_k, dim=1).indices\n",
    "                # rank_hits[i, j] is 1 when the (j+1)-th ranked class of sample i is its label.\n",
    "                rank_hits = predicted_topk.eq(label.unsqueeze(1)).long()\n",
    "                # The ranked classes of a sample are distinct, so the running sum along the\n",
    "                # rank axis is 1 from the position of the label onwards: column k-1 counts\n",
    "                # the samples whose label is among their first k predictions.\n",
    "                topk_hits += rank_hits.cumsum(dim=1).sum(dim=0).float().cpu()\n",
    "\n",
    "        num_test_samples = len(test_loader.sampler)\n",
    "        test_loss = test_loss / num_test_samples\n",
    "        topk_accuracies_test[corruption_name] = (topk_hits / num_test_samples).numpy()\n",
    "\n",
    "        print(\"Model: {} \\tTest Loss: {:.6f}\".format(\n",
    "            model_filenames[model_num], \n",
    "            test_loss\n",
    "            ))\n",
    "\n",
    "        for iter_topk in range(0, top_k):\n",
    "            print(\"Top {} test accuracy on corruption {}:\".format(\n",
    "                iter_topk+1,\n",
    "                corruption_name\n",
    "                ), \n",
    "                topk_accuracies_test[corruption_name][iter_topk], flush=True)\n",
    "\n",
    "    results = {}\n",
    "    results[\"topk_accuracies\"] = topk_accuracies_test\n",
    "    results[\"model_name\"] = model_filenames[model_num]\n",
    "    results[\"top_k\"] = top_k\n",
    "\n",
    "    # np.save appends '.npy' to the file name and stores the dict as a pickled object array.\n",
    "    np.save(working_dir_of_save + model_filenames[model_num], results)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.8",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
