{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classification on CIFAR and ImageNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import yaml\n",
    "import copy\n",
    "from pathlib import Path\n",
    "import datetime\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "import models\n",
    "import ops.trains as trains\n",
    "import ops.tests as tests\n",
    "import ops.datasets as datasets\n",
    "import ops.schedulers as schedulers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# config_path = \"configs/cifar10_general.yaml\"\n",
    "config_path = \"configs/cifar100_general.yaml\"\n",
    "# config_path = \"configs/imagenet_general.yaml\"\n",
    "\n",
    "with open(config_path) as f:\n",
    "    args = yaml.load(f, Loader=yaml.FullLoader)\n",
    "    print(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_args = copy.deepcopy(args).get(\"dataset\")\n",
    "train_args = copy.deepcopy(args).get(\"train\")\n",
    "val_args = copy.deepcopy(args).get(\"val\")\n",
    "model_args = copy.deepcopy(args).get(\"model\")\n",
    "optim_args = copy.deepcopy(args).get(\"optim\")\n",
    "env_args = copy.deepcopy(args).get(\"env\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_train, dataset_test = datasets.get_dataset(**dataset_args, download=True)\n",
    "dataset_name = dataset_args[\"name\"]\n",
    "num_classes = len(dataset_train.classes)\n",
    "\n",
    "dataset_train = DataLoader(dataset_train, \n",
    "                           shuffle=True, \n",
    "                           num_workers=train_args.get(\"num_workers\", 4), \n",
    "                           batch_size=train_args.get(\"batch_size\", 128))\n",
    "dataset_test = DataLoader(dataset_test, \n",
    "                          num_workers=val_args.get(\"num_workers\", 4), \n",
    "                          batch_size=val_args.get(\"batch_size\", 128))\n",
    "\n",
    "print(\"Train: %s, Test: %s, Classes: %s\" % (\n",
    "    len(dataset_train.dataset), \n",
    "    len(dataset_test.dataset), \n",
    "    num_classes\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# AlexNet\n",
    "# name = \"alexnet_dnn\"\n",
    "# name = \"alexnet_mcdo\"\n",
    "# name = \"alexnet_dnn_smoothing\"\n",
    "# name = \"alexnet_mcdo_smoothing\"\n",
    "\n",
    "# VGG\n",
    "# name = \"vgg_dnn_16\"\n",
    "# name = \"vgg_mcdo_16\"\n",
    "# name = \"vgg_dnn_smoothing_16\"\n",
    "# name = \"vgg_mcdo_smoothing_16\"\n",
    "\n",
    "# Preact VGG\n",
    "# name = \"prevgg_dnn_16\"\n",
    "# name = \"prevgg_mcdo_16\"\n",
    "# name = \"prevgg_dnn_smoothing_16\"\n",
    "# name = \"prevgg_mcdo_smoothing_16\"\n",
    "\n",
    "# ResNet\n",
    "name = \"resnet_dnn_18\"\n",
    "# name = \"resnet_mcdo_18\"\n",
    "# name = \"resnet_dnn_smoothing_18\"\n",
    "# name = \"resnet_mcdo_smoothing_18\"\n",
    "\n",
    "# name = \"resnet_dnn_50\"\n",
    "# name = \"resnet_mcdo_50\"\n",
    "# name = \"resnet_dnn_smoothing_50\"\n",
    "# name = \"resnet_mcdo_smoothing_50\"\n",
    "\n",
    "# Preact ResNet\n",
    "# name = \"preresnet_dnn_50\"\n",
    "# name = \"preresnet_mcdo_50\"\n",
    "# name = \"preresnet_dnn_smoothing_50\"\n",
    "# name = \"preresnet_mcdo_smoothing_50\"\n",
    "\n",
    "# ResNeXt\n",
    "# name = \"resnext_dnn_50\"\n",
    "# name = \"resnext_mcdo_50\"\n",
    "# name = \"resnext_dnn_smoothing_50\"\n",
    "# name = \"resnext_mcdo_smoothing_50\"\n",
    "\n",
    "# WideResNet\n",
    "# name = \"wideresnet_dnn_50\"\n",
    "# name = \"wideresnet_mcdo_50\"\n",
    "# name = \"wideresnet_dnn_smoothing_50\"\n",
    "# name = \"wideresnet_mcdo_smoothing_50\"\n",
    "\n",
    "# SEResNet\n",
    "# name = \"seresnet_dnn_18\"\n",
    "# name = \"seresnet_mcdo_18\"\n",
    "# name = \"seresnet_dnn_smoothing_18\"\n",
    "# name = \"seresnet_mcdo_smoothing_18\"\n",
    "\n",
    "# name = \"seresnet_dnn_50\"\n",
    "# name = \"seresnet_mcdo_50\"\n",
    "# name = \"seresnet_dnn_smoothing_50\"\n",
    "# name = \"seresnet_mcdo_smoothing_50\"\n",
    "\n",
    "# ViT\n",
    "# name = \"vit_ti\"\n",
    "# name = \"vit_s\"\n",
    "# name = \"vit_b\"\n",
    "# name = \"vit_l\"\n",
    "# name = \"vit_h\"\n",
    "\n",
    "# PiT\n",
    "# name = \"pit_ti\"\n",
    "# name = \"pit_xs\"\n",
    "# name = \"pit_s\"\n",
    "# name = \"pit_b\"\n",
    "\n",
    "# Mixer\n",
    "# name = \"mixer_ti\"\n",
    "# name = \"mixer_s\"\n",
    "# name = \"mixer_b\"\n",
    "# name = \"mixer_l\"\n",
    "# name = \"mixer_h\"\n",
    "\n",
    "vit_kwargs = {\n",
    "    \"image_size\": 32, \n",
    "    \"patch_size\": 2,\n",
    "}\n",
    "\n",
    "model = models.get_model(name, num_classes=num_classes, \n",
    "                         stem=model_args.get(\"stem\", False), **vit_kwargs)\n",
    "# models.load(model, dataset_name, uid=current_time)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Parallelize the given `moodel` by splitting the input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "name = model.name\n",
    "model = nn.DataParallel(model)\n",
    "model.name = name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a TensorBoard writer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "current_time = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "log_dir = os.path.join(\"runs\", dataset_name, model.name, current_time)\n",
    "writer = SummaryWriter(log_dir)\n",
    "\n",
    "with open(\"%s/config.yaml\" % log_dir, \"w\") as f:\n",
    "    yaml.dump(args, f)\n",
    "with open(\"%s/model.log\" % log_dir, \"w\") as f:\n",
    "    f.write(repr(model))\n",
    "\n",
    "print(\"Create TensorBoard log dir: \", log_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "gpu = torch.cuda.is_available()\n",
    "optimizer, train_scheduler = trains.get_optimizer(model, **optim_args)\n",
    "warmup_scheduler = schedulers.WarmupScheduler(optimizer, len(dataset_train) * train_args.get(\"warmup_epochs\", 0))\n",
    "\n",
    "trains.train(model, optimizer,\n",
    "             dataset_train, dataset_test,\n",
    "             train_scheduler, warmup_scheduler,\n",
    "             train_args, val_args, gpu,\n",
    "             writer, \n",
    "             snapshot=-1, dataset_name=dataset_name, uid=current_time)  # Set `snapshot=N` to save snapshots every N epochs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models.save(model, dataset_name, current_time, optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gpu = torch.cuda.is_available()\n",
    "\n",
    "model = model.cuda() if gpu else model.cpu()\n",
    "metrics_list = []\n",
    "for n_ff in [1]:\n",
    "    print(\"N: %s, \" % n_ff, end=\"\")\n",
    "    *metrics, cal_diag = tests.test(model, n_ff, dataset_test, verbose=False, gpu=gpu)\n",
    "    metrics_list.append([n_ff, *metrics])\n",
    "\n",
    "leaderboard_path = os.path.join(\"leaderboard\", \"logs\", dataset_name, model.name)\n",
    "Path(leaderboard_path).mkdir(parents=True, exist_ok=True)\n",
    "metrics_dir = os.path.join(leaderboard_path, \"%s_%s_%s.csv\" % (dataset_name, model.name, current_time))\n",
    "tests.save_metrics(metrics_dir, metrics_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gpu = torch.cuda.is_available()\n",
    "\n",
    "model = model.cuda() if gpu else model.cpu()\n",
    "metrics_list = []\n",
    "for n_ff in [1, 2, 3, 4, 5, 10, 20, 50]:\n",
    "    print(\"N: %s, \" % n_ff, end=\"\")\n",
    "    *metrics, cal_diag = tests.test(model, n_ff, dataset_test, verbose=False, gpu=gpu)\n",
    "    metrics_list.append([n_ff, *metrics])\n",
    "\n",
    "leaderboard_path = os.path.join(\"leaderboard\", \"logs\", dataset_name, model.name)\n",
    "Path(leaderboard_path).mkdir(parents=True, exist_ok=True)\n",
    "metrics_dir = os.path.join(leaderboard_path, \"%s_%s_%s.csv\" % (dataset_name, model.name, current_time))\n",
    "tests.save_metrics(metrics_dir, metrics_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
