{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and general settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import os\n",
    "import sys\n",
    "from torch.utils.data import DataLoader\n",
    "from helpers import set_seeds, set_cuda_randomness\n",
    "from init import init_model, init_check\n",
    "from data import init_dataset, split_dataset\n",
    "from modelvshuman import models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## General settings, seeding and validation data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# General settings\n",
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "BATCH_SIZE = 256\n",
    "NUM_WORKERS = 30  # DataLoader worker processes; lower this on machines with fewer CPU cores\n",
    "CUDA = 0  # flag for set_cuda_randomness (deterministic vs. non-deterministic CUDA) - TODO document which value means which\n",
    "VERBOSE = 1  # verbosity flag for optional progress output\n",
    "GLOBAL_SEED = 1312\n",
    "INIT_SEED = 1312\n",
    "\n",
    "# Set seeds and set CUDA to be deterministic or non-deterministic\n",
    "set_seeds(GLOBAL_SEED)\n",
    "set_cuda_randomness(CUDA)\n",
    "\n",
    "# The four positional arguments after the dataset name are presumably unused for the\n",
    "# validation split (TODO: confirm against init_dataset). Do not pass IPython's `_` here:\n",
    "# it is '' only on a fresh kernel and becomes the last displayed output afterwards.\n",
    "UNUSED_ARG = \"\"\n",
    "val_set = init_dataset(\"ImageNet\", UNUSED_ARG, UNUSED_ARG, UNUSED_ARG, UNUSED_ARG, train=False)\n",
    "\n",
    "# Do not shuffle, so predictions stay in the same dataset order on every run\n",
    "val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)\n",
    "criterion = nn.CrossEntropyLoss().to(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CORnet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
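    "Run **either** the CORnet code (uncomment its cells first) **or** the model-vs-human code, not both. Both paths define `model`, `filename`, `acc_val`, `epoch_output` and `epoch_targets`, which the save cell at the end reads."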
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get Cornet\n",
    "# cornetrt = torch.utils.model_zoo.load_url(\"https://s3.amazonaws.com/cornet-models/cornet_rt-933c001c.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# filename = f\"./results/Sota/CorNet/NUM1/\"\n",
    "# from cornet_rt import CORnet_RT\n",
    "# from cornet_rt import HASH as HASH_RT\n",
    "\n",
    "# def get_model(model_letter, pretrained=False, map_location=None, **kwargs):\n",
    "#     model_letter = model_letter.upper()\n",
    "#     model_hash = globals()[f'HASH_{model_letter}']\n",
    "#     model = globals()[f'CORnet_{model_letter}'](**kwargs)\n",
    "#     model = torch.nn.DataParallel(model)\n",
    "#     if pretrained:\n",
    "#         url = f'https://s3.amazonaws.com/cornet-models/cornet_{model_letter.lower()}-{model_hash}.pth'\n",
    "#         ckpt_data = torch.utils.model_zoo.load_url(url, map_location=map_location)\n",
    "#         model.load_state_dict(ckpt_data['state_dict'])\n",
    "#     return model\n",
    "\n",
    "\n",
    "# def cornet_rt(pretrained=False, map_location=None, times=5):\n",
    "#     return get_model('rt', pretrained=pretrained, map_location=map_location, times=times)\n",
    "\n",
    "# cornet = get_model(\"RT\", pretrained=True)\n",
    "# model = cornet\n",
    "\n",
    "# #Set filename\n",
    "# filename = f\"./results/Rebuttal/CorNet/NUM1/\"\n",
    "\n",
    "# # Pre-allocate arrays for training and results (+1 for test set before model enters training)\n",
    "# loss_val = torch.zeros(1).to(DEVICE)\n",
    "# acc_val = torch.zeros(1).to(DEVICE)\n",
    "\n",
    "\n",
    "# # Disable gradient for test dataset\n",
    "# model.eval()\n",
    "# with torch.no_grad():\n",
    "\n",
    "#     # Re-set seed to global seed\n",
    "#     set_seeds(GLOBAL_SEED)\n",
    "\n",
    "#     # Run test set\n",
    "#     for i, (images, targets) in enumerate(val_loader):\n",
    "\n",
    "#         # Load images and targets onto GPU\n",
    "#         images = images.to(DEVICE)\n",
    "#         targets = targets.to(DEVICE)\n",
    "\n",
    "#         # Get output and loss\n",
    "#         output = model(images)\n",
    "\n",
    "#         # Compute accuracy\n",
    "#         acc_val[0] += torch.sum(torch.eq(targets, torch.argmax(output, dim=1))) / len(val_loader.dataset)\n",
    "\n",
    "#         # Prepare outputs to be saved for each epoch\n",
    "#         if i == 0:\n",
    "#             epoch_output = output\n",
    "#             epoch_targets = targets\n",
    "#         else:\n",
    "#             epoch_output = torch.cat((epoch_output, output))\n",
    "#             epoch_targets = torch.cat((epoch_targets, targets))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Models from model-vs-human (`modelvshuman`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose your favourite model from the modelvshuman model zoo.\n",
    "# Here the zoo's factory function is called with the model's name.\n",
    "from modelvshuman.models.pytorch.model_zoo import resnet50_swsl\n",
    "\n",
    "model = resnet50_swsl(\"resnet50_swsl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory for the save cell (must end with a slash)\n",
    "filename = \"./results/Sota/resnet50_swsl/NUM1/\"\n",
    "\n",
    "# Single-element result tensors, kept as tensors so the save cell can torch.save them\n",
    "loss_val = torch.zeros(1).to(DEVICE)\n",
    "acc_val = torch.zeros(1).to(DEVICE)\n",
    "\n",
    "# Collect per-batch tensors in lists and concatenate once after the loop;\n",
    "# torch.cat inside the loop would re-copy the growing tensor every batch (quadratic).\n",
    "batch_outputs = []\n",
    "batch_targets = []\n",
    "num_correct = 0\n",
    "num_samples = len(val_loader.dataset)\n",
    "\n",
    "# NOTE(review): model.eval() is not called here; the modelvshuman wrapper's\n",
    "# forward_batch is assumed to switch the network to eval mode itself - confirm.\n",
    "with torch.no_grad():  # no gradients needed for evaluation\n",
    "\n",
    "    # Re-set seed to global seed\n",
    "    set_seeds(GLOBAL_SEED)\n",
    "\n",
    "    for i, (images, targets) in enumerate(val_loader):\n",
    "        if VERBOSE:\n",
    "            print(f\"batch {i + 1}/{len(val_loader)}\")\n",
    "\n",
    "        # Load images and targets onto the device\n",
    "        images = images.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "\n",
    "        # forward_batch returns a NumPy array; convert it back to a tensor on DEVICE\n",
    "        output_numpy = model.forward_batch(images)\n",
    "        output = torch.as_tensor(output_numpy, device=DEVICE)\n",
    "\n",
    "        # Sample-weighted loss sum (assumes forward_batch returns unnormalised logits,\n",
    "        # as CrossEntropyLoss expects - confirm) and exact count of correct top-1 predictions\n",
    "        loss_val[0] += criterion(output, targets) * targets.size(0)\n",
    "        num_correct += torch.sum(torch.eq(targets, torch.argmax(output, dim=1))).item()\n",
    "\n",
    "        batch_outputs.append(output)\n",
    "        batch_targets.append(targets)\n",
    "\n",
    "assert batch_outputs, \"val_loader yielded no batches\"\n",
    "\n",
    "# Dataset-level metrics: mean loss and top-1 accuracy\n",
    "loss_val /= num_samples\n",
    "acc_val[0] = num_correct / num_samples\n",
    "\n",
    "# Full outputs and targets in dataset order (read by the save cell)\n",
    "epoch_output = torch.cat(batch_outputs)\n",
    "epoch_targets = torch.cat(batch_targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save CORnet or modelvshuman results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle predicted labels, true labels and accuracy\n",
    "result = [[torch.argmax(epoch_output, dim=1), epoch_targets], acc_val]\n",
    "\n",
    "# torch.save cannot create missing directories, so make sure the output directory exists\n",
    "os.makedirs(filename, exist_ok=True)\n",
    "\n",
    "# Use torch save to write results and model into file.\n",
    "# NOTE: despite the .txt suffix these are torch.save (pickle) files, not text; read them with torch.load.\n",
    "torch.save(result, filename + \"RESULTS_EP0.txt\")\n",
    "torch.save(acc_val, filename + \"VAL_ACC.txt\")\n",
    "# Pickles the whole model object; loading it later requires the defining code to be importable\n",
    "torch.save(model, filename + \"MODEL_EP0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
