{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train Packed Ensembles of ResNet50 on CIFAR-10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import warnings\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from torchvision.datasets import CIFAR10, SVHN\n",
    "import torchvision.transforms as T\n",
    "\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.loggers.tensorboard import TensorBoardLogger\n",
    "from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint\n",
    "from pytorch_lightning.callbacks import RichProgressBar\n",
    "\n",
    "from pysemble.lightning_module import PackedEns\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Set seed to 0\n",
    "pl.seed_everything(0, workers=True);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ROOT = Path(\"./data/\")\n",
    "N_SUBNETS = 4\n",
    "ALPHA = 2\n",
    "GAMMA = 2\n",
    "\n",
    "MAX_EPOCHS = 200\n",
    "BATCH_SIZE = 128\n",
    "NUM_WORKERS = 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CIFAR-10 Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform_train = T.Compose(\n",
    "    [\n",
    "        T.RandomCrop(32, padding=4),\n",
    "        T.RandomHorizontalFlip(),\n",
    "        T.ToTensor(),\n",
    "        T.Normalize(\n",
    "            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010),\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "\n",
    "transform_test = T.Compose(\n",
    "    [\n",
    "        T.ToTensor(),\n",
    "        T.Normalize(\n",
    "            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010),\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "\n",
    "cifar10_train_dataset = CIFAR10(root=ROOT, train=True, download=True, transform=transform_train)\n",
    "cifar10_test_dataset = CIFAR10(root=ROOT, train=False, transform=transform_test)\n",
    "svhn_test_dataset = SVHN(root=ROOT, split=\"test\", download=True, transform=transform_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cifar10_train_loader = DataLoader(\n",
    "    cifar10_train_dataset,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    pin_memory=True,\n",
    "    persistent_workers=True,\n",
    ")\n",
    "cifar10_test_loader = DataLoader(\n",
    "    cifar10_test_dataset,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=False,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    pin_memory=True,\n",
    "    persistent_workers=True,\n",
    ")\n",
    "svhn_test_loader = DataLoader(\n",
    "    svhn_test_dataset,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=False,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    pin_memory=True,\n",
    "    persistent_workers=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Lightning Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = PackedEns(N_SUBNETS, ALPHA, GAMMA, num_classes=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Lightning Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tb_logger = TensorBoardLogger(\n",
    "    str(ROOT / \"logs\"),\n",
    "    name=\"packed_ens\",\n",
    "    default_hp_metric=False,\n",
    ")\n",
    "\n",
    "best_checkpoint = ModelCheckpoint(\n",
    "    monitor=\"hp/val_acc\", mode=\"max\", save_weights_only=False,\n",
    ")\n",
    "\n",
    "progress_bar = RichProgressBar(refresh_rate=10)\n",
    "\n",
    "callbacks = [best_checkpoint, progress_bar]\n",
    "\n",
    "trainer = pl.Trainer(\n",
    "    logger=tb_logger,\n",
    "    callbacks=callbacks,\n",
    "    max_epochs=MAX_EPOCHS,\n",
    "    deterministic=True,\n",
    "    accelerator=\"gpu\",\n",
    "    devices=1,\n",
    "    precision=16\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.fit(model, train_dataloaders=cifar10_train_loader, val_dataloaders=cifar10_test_loader);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.test(dataloaders=[cifar10_test_loader, svhn_test_loader], ckpt_path=\"best\");"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 ('deep')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f3530934b33069a982b3b8110fefe3d4a650a9a34208896fd915fb81f91185ec"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
