{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "83b1c45c",
   "metadata": {},
   "source": [
    "# Robustness on CIFAR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bade143",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import json\n",
    "import copy\n",
    "from pathlib import Path\n",
    "import datetime\n",
    "import csv\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "import models\n",
    "import ops.trains as trains\n",
    "import ops.tests as tests\n",
    "import ops.datasets as datasets\n",
    "import ops.schedulers as schedulers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c6635ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# config_path = \"configs/cifar10_general.json\"\n",
    "config_path = \"configs/cifar100_general.json\"\n",
    "# config_path = \"configs/imagenet_general.json\"\n",
    "\n",
    "with open(config_path) as f:\n",
    "    args = json.load(f)\n",
    "    print(\"args: \\n\", args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93447458",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_args = copy.deepcopy(args).get(\"dataset\")\n",
    "train_args = copy.deepcopy(args).get(\"train\")\n",
    "val_args = copy.deepcopy(args).get(\"val\")\n",
    "model_args = copy.deepcopy(args).get(\"model\")\n",
    "optim_args = copy.deepcopy(args).get(\"optim\")\n",
    "env_args = copy.deepcopy(args).get(\"env\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc0b8f1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_train, dataset_test = datasets.get_dataset(**dataset_args, download=True)\n",
    "dataset_name = dataset_args[\"name\"]\n",
    "num_classes = len(dataset_train.classes)\n",
    "\n",
    "dataset_train = DataLoader(dataset_train, \n",
    "                           shuffle=True, \n",
    "                           num_workers=train_args.get(\"num_workers\", 4), \n",
    "                           batch_size=train_args.get(\"batch_size\", 128))\n",
    "dataset_test = DataLoader(dataset_test, \n",
    "                          num_workers=val_args.get(\"num_workers\", 4), \n",
    "                          batch_size=val_args.get(\"batch_size\", 128))\n",
    "\n",
    "print(\"Train: %s, Test: %s, Classes: %s\" % (\n",
    "    len(dataset_train.dataset), \n",
    "    len(dataset_test.dataset), \n",
    "    num_classes\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b1b5cf3",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a4d74ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AlexNet\n",
    "# name = \"alexnet_dnn\"\n",
    "# name = \"alexnet_dnn_smoothing\"\n",
    "# name = \"alexnet_mcdo\"\n",
    "# name = \"alexnet_mcdo_smoothing\"\n",
    "\n",
    "# VGG\n",
    "# name = \"vgg_dnn_19\"\n",
    "# name = \"vgg_dnn_smoothing_19\"\n",
    "# name = \"vgg_mcdo_19\"\n",
    "# name = \"vgg_mcdo_smoothing_19\"\n",
    "\n",
    "# ResNet\n",
    "name = \"resnet_dnn_18\"\n",
    "# name = \"resnet_dnn_smoothing_18\"\n",
    "# name = \"resnet_mcdo_18\"\n",
    "# name = \"resnet_mcdo_smoothing_18\"\n",
    "\n",
    "# name = \"resnet_dnn_50\"\n",
    "# name = \"resnet_mcdo_50\"\n",
    "# name = \"resnet_dnn_smoothing_50\"\n",
    "# name = \"resnet_mcdo_smoothing_50\"\n",
    "\n",
    "# Preact ResNet\n",
    "# name = \"preresnet_dnn_50\"\n",
    "# name = \"preresnet_mcdo_50\"\n",
    "# name = \"preresnet_dnn_smoothing_50\"\n",
    "# name = \"preresnet_mcdo_smoothing_50\"\n",
    "\n",
    "# ResNeXt\n",
    "# name = \"resnext_dnn_50\"\n",
    "# name = \"resnext_mcdo_50\"\n",
    "# name = \"resnext_dnn_smoothing_50\"\n",
    "# name = \"resnext_mcdo_smoothing_50\"\n",
    "\n",
    "# WideResNet\n",
    "# name = \"wideresnet_dnn_50\"\n",
    "# name = \"wideresnet_mcdo_50\"\n",
    "# name = \"wideresnet_dnn_smoothing_50\"\n",
    "# name = \"wideresnet_mcdo_smoothing_50\"\n",
    "\n",
    "\n",
    "uid = \"\"  # Model UID required\n",
    "model = models.get_model(name, num_classes=num_classes, \n",
    "                         stem=model_args.get(\"stem\", False))\n",
    "models.load(model, dataset_name, uid)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3dc83468",
   "metadata": {},
   "source": [
    "Parallelize the given `moodel` by splitting the input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0ee6c23",
   "metadata": {},
   "outputs": [],
   "source": [
    "name = model.name\n",
    "model = nn.DataParallel(model)\n",
    "model.name = name"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cde2a379",
   "metadata": {},
   "source": [
    "Test model performance on in-domain data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7943b2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "gpu = torch.cuda.is_available()\n",
    "\n",
    "model = model.cuda() if gpu else model.cpu()\n",
    "metrics_list = []\n",
    "for n_ff in [1]:\n",
    "    print(\"N: %s, \" % n_ff, end=\"\")\n",
    "    *metrics, cal_diag = tests.test(model, n_ff, dataset_test, verbose=False, gpu=gpu)\n",
    "    metrics_list.append([n_ff, *metrics])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fda55c4e",
   "metadata": {},
   "source": [
    "## Corruption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95bccfc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_ff = 1\n",
    "\n",
    "gpu = torch.cuda.is_available()\n",
    "val_args = copy.deepcopy(args).get(\"val\")\n",
    "dataset_args = copy.deepcopy(args).get(\"dataset\")\n",
    "\n",
    "model = model.cuda() if gpu else model.cpu()\n",
    "metrics_c = { intensity: {} for intensity in range(1, 6) }\n",
    "for intensity in range(1, 6):\n",
    "    for ctype in datasets.get_corruptions():\n",
    "        dataset_c = datasets.get_dataset_c(**dataset_args, ctype=ctype, intensity=intensity, download=True)\n",
    "        dataset_c = DataLoader(dataset_c, \n",
    "                               num_workers=val_args.get(\"num_workers\", 4), \n",
    "                               batch_size=val_args.get(\"batch_size\", 128))\n",
    "        print(\"Corruption type: %s, Intensity: %d, \" % (ctype, intensity), end=\"\")\n",
    "        *metrics, cal_diag = tests.test(model, n_ff, dataset_c, verbose=False, gpu=gpu)\n",
    "        metrics_c[intensity][ctype] = metrics\n",
    "\n",
    "leaderboard_path = os.path.join(\"leaderboard\", \"logs\", dataset_name, model.name)\n",
    "Path(leaderboard_path).mkdir(parents=True, exist_ok=True)\n",
    "metrics_dir = os.path.join(leaderboard_path, \"%s_%s_%s_%s_corrupted.csv\" % (dataset_name, model.name, uid, n_ff))\n",
    "metrics_c_list = [[i, typ, *metrics] for i, typ_metrics in metrics_c.items() for typ, metrics in typ_metrics.items()]\n",
    "tests.save_metrics(metrics_dir, metrics_c_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bae51b48",
   "metadata": {},
   "source": [
    "## Perturbation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac255912",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_ff = 1\n",
    "\n",
    "gpu = torch.cuda.is_available()\n",
    "val_args = copy.deepcopy(args).get(\"val\")\n",
    "dataset_args = copy.deepcopy(args).get(\"dataset\")\n",
    "\n",
    "metrics_p = {}\n",
    "for ptype in datasets.get_perturbations():\n",
    "    dataset_p = datasets.get_cifar10p(ptype=ptype,\n",
    "                                      root=\"../data\",\n",
    "                                      base_folder=\"cifar-10-p\")\n",
    "    dataset_p = DataLoader(dataset_p, \n",
    "                           num_workers=val_args.get(\"num_workers\", 4), \n",
    "                           batch_size=val_args.get(\"batch_size\", 128))\n",
    "    metrics = tests.test_perturbation(dataset_p, model, n_ff=n_ff)\n",
    "    print(\"Perturbation type: %s, Consistency: %.5f, CEC: %.5f\" % \n",
    "          (ptype, metrics[0], metrics[1]))\n",
    "    metrics_p[ptype] = metrics\n",
    "\n",
    "leaderboard_path = os.path.join(\"leaderboard\", \"logs\", dataset_name, model.name)\n",
    "Path(leaderboard_path).mkdir(parents=True, exist_ok=True)\n",
    "metrics_dir = os.path.join(leaderboard_path, \"%s_%s_%s_%s_perturbated.csv\" % (dataset_name, model.name, uid, n_ff))\n",
    "metrics_p_list = [[typ, *metrics] for typ, metrics in metrics_p.items()]\n",
    "tests.save_lists(metrics_dir, metrics_p_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fe920ab",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
