{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and general settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import os\n",
    "import sys\n",
    "from torch.utils.data import DataLoader\n",
    "from helpers import set_seeds, set_cuda_randomness\n",
    "from init import init_model, init_check\n",
    "from data import init_dataset, split_dataset\n",
    "from modelvshuman import models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## General settings, seeds, and validation data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# General settings\n",
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "BATCH_SIZE = 256\n",
    "NUM_WORKERS = 30  # DataLoader worker processes for the validation loader\n",
    "CUDA = 0  # flag passed to set_cuda_randomness (deterministic vs. non-deterministic CUDA)\n",
    "VERBOSE = 1\n",
    "GLOBAL_SEED = 1312\n",
    "INIT_SEED = 1312\n",
    "\n",
    "# Set seeds and set CUDA to be deterministic or non-deterministic\n",
    "set_seeds(GLOBAL_SEED)\n",
    "set_cuda_randomness(CUDA)\n",
    "\n",
    "# NOTE(review): `_` below is IPython's last-output variable, not a real argument. Its value\n",
    "# depends on which cells ran before (it changes after any cell that displays a value), so this\n",
    "# call is not reproducible under Restart & Run All. TODO: pass the explicit values init_dataset expects.\n",
    "val_set = init_dataset(\"ImageNet\", _, _, _, _, train=False)\n",
    "val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)  # Do not shuffle to keep same order\n",
    "criterion = nn.CrossEntropyLoss().to(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run validation set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Models of interest (all but CorNet-RT appear in the modelvshuman list printed below):\n",
    "# InsDis\n",
    "# simclr_resnet50x1\n",
    "# vit_base_patch16_224\n",
    "# resnet50_l2_eps0\n",
    "# bagnet33\n",
    "# CorNet-RT "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import torchvision's model zoo under its own alias: binding it to `models` would shadow\n",
    "# `modelvshuman.models` (imported at the top) and break `models.list_models(...)` below.\n",
    "import torchvision.models as tv_models\n",
    "squeezenet = tv_models.squeezenet1_1(pretrained=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['resnet50_trained_on_SIN', 'resnet50_trained_on_SIN_and_IN', 'resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN', 'bagnet9', 'bagnet17', 'bagnet33', 'simclr_resnet50x1_supervised_baseline', 'simclr_resnet50x4_supervised_baseline', 'simclr_resnet50x1', 'simclr_resnet50x2', 'simclr_resnet50x4', 'InsDis', 'MoCo', 'MoCoV2', 'PIRL', 'InfoMin', 'resnet50_l2_eps0', 'resnet50_l2_eps0_01', 'resnet50_l2_eps0_03', 'resnet50_l2_eps0_05', 'resnet50_l2_eps0_1', 'resnet50_l2_eps0_25', 'resnet50_l2_eps0_5', 'resnet50_l2_eps1', 'resnet50_l2_eps3', 'resnet50_l2_eps5', 'efficientnet_b0', 'efficientnet_es', 'efficientnet_b0_noisy_student', 'efficientnet_l2_noisy_student_475', 'transformer_B16_IN21K', 'transformer_B32_IN21K', 'transformer_L16_IN21K', 'transformer_L32_IN21K', 'vit_small_patch16_224', 'vit_base_patch16_224', 'vit_large_patch16_224', 'cspresnet50', 'cspresnext50', 'cspdarknet53', 'darknet53', 'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn131', 'dpn107', 'hrnet_w18_small', 'hrnet_w18_small', 'hrnet_w18_small_v2', 'hrnet_w18', 'hrnet_w30', 'hrnet_w40', 'hrnet_w44', 'hrnet_w48', 'hrnet_w64', 'selecsls42', 'selecsls84', 'selecsls42b', 'selecsls60', 'selecsls60b', 'clip', 'clipRN50', 'resnet50_swsl', 'ResNeXt101_32x16d_swsl', 'BiTM_resnetv2_50x1', 'BiTM_resnetv2_50x3', 'BiTM_resnetv2_101x1', 'BiTM_resnetv2_101x3', 'BiTM_resnetv2_152x2', 'BiTM_resnetv2_152x4', 'resnet50_clip_hard_labels', 'resnet50_clip_soft_labels']\n"
     ]
    }
   ],
   "source": [
    "print(models.list_models(\"pytorch\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: \"https://s3.amazonaws.com/cornet-models/cornet_rt-933c001c.pth\" to /home/wichmann/lschulzebuschoff43/.cache/torch/hub/checkpoints/cornet_rt-933c001c.pth\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "26f71d22800a4095a915e6576ad6bbcb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0.00/39.8M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# cornetrt = torch.utils.model_zoo.load_url(\"https://s3.amazonaws.com/cornet-models/cornet_rt-933c001c.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from cornet_rt import CORnet_RT\n",
    "# from cornet_rt import HASH as HASH_RT\n",
    "\n",
    "# def get_model(model_letter, pretrained=False, map_location=None, **kwargs):\n",
    "#     model_letter = model_letter.upper()\n",
    "#     model_hash = globals()[f'HASH_{model_letter}']\n",
    "#     model = globals()[f'CORnet_{model_letter}'](**kwargs)\n",
    "#     model = torch.nn.DataParallel(model)\n",
    "#     if pretrained:\n",
    "#         url = f'https://s3.amazonaws.com/cornet-models/cornet_{model_letter.lower()}-{model_hash}.pth'\n",
    "#         ckpt_data = torch.utils.model_zoo.load_url(url, map_location=map_location)\n",
    "#         model.load_state_dict(ckpt_data['state_dict'])\n",
    "#     return model\n",
    "\n",
    "\n",
    "# def cornet_rt(pretrained=False, map_location=None, times=5):\n",
    "#     return get_model('rt', pretrained=pretrained, map_location=map_location, times=times)\n",
    "\n",
    "# cornet = get_model(\"RT\", pretrained=True)\n",
    "# model = cornet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modelvshuman.models.pytorch.model_zoo import resnet50_swsl\n",
    "model = resnet50_swsl(\"resnet50_swsl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Test accuracy:  0.5818\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the loaded `model` once on the validation set. Afterwards loss_val / acc_val hold\n",
    "# the dataset-mean loss and top-1 accuracy, and epoch_output / epoch_targets hold all outputs\n",
    "# and targets in loader order.\n",
    "\n",
    "# Results directory; keep in sync with the model loaded above (not yet used for saving)\n",
    "filename = \"./results/resnet50_swsl/NUM1/\"\n",
    "\n",
    "# Single-element accumulators for validation loss and accuracy\n",
    "loss_val = torch.zeros(1).to(DEVICE)\n",
    "acc_val = torch.zeros(1).to(DEVICE)\n",
    "n_samples = len(val_loader.dataset)\n",
    "\n",
    "# Collect per-batch results and concatenate once after the loop; calling torch.cat inside\n",
    "# the loop re-copies everything seen so far on every batch\n",
    "batch_outputs = []\n",
    "batch_targets = []\n",
    "\n",
    "# Disable gradient tracking for evaluation\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "\n",
    "    # Re-set seed to global seed\n",
    "    set_seeds(GLOBAL_SEED)\n",
    "\n",
    "    # Run validation set\n",
    "    for images, targets in val_loader:\n",
    "\n",
    "        # Load images and targets onto the device\n",
    "        images = images.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "\n",
    "        # Get output and loss; criterion averages over the batch, so re-weight by batch size\n",
    "        output = model(images)\n",
    "        loss_val[0] += criterion(output, targets) * targets.size(0) / n_samples\n",
    "\n",
    "        # Compute accuracy as a fraction of the whole dataset\n",
    "        acc_val[0] += torch.sum(torch.eq(targets, torch.argmax(output, dim=1))) / n_samples\n",
    "\n",
    "        batch_outputs.append(output)\n",
    "        batch_targets.append(targets)\n",
    "\n",
    "# Outputs and targets for the whole validation set, in loader order\n",
    "epoch_output = torch.cat(batch_outputs)\n",
    "epoch_targets = torch.cat(batch_targets)\n",
    "\n",
    "print(f\"Validation loss: {loss_val.item():.4f}, Validation accuracy: {acc_val.item():.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
