{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2a6d168c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as T\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "import numpy as np\n",
    "import time\n",
    "import os\n",
    "\n",
    "from models import MNISTClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "adc17b68",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cutout(input):\n",
    "    '''\n",
    "    Zero out one random rectangle in every image of the batch, in place.\n",
    "\n",
    "    A single rectangle is drawn per call and shared by the whole batch. Its\n",
    "    sides are H and W scaled by sqrt(1 - lamb) with lamb ~ U(0, 1), it is\n",
    "    centred on a uniformly random point, and it is clipped to the image.\n",
    "\n",
    "    input: 4-D tensor indexed as (B, C, H, W); it is modified in place.\n",
    "    Returns the same tensor object. With the Normalize((0.5), (0.5)) used in\n",
    "    this notebook, the written value 0 corresponds to mid-grey, not black.\n",
    "    '''\n",
    "    lamb = np.random.uniform(0.0, 1.0)\n",
    "\n",
    "    H, W = input.shape[2:]\n",
    "    r_x = np.random.uniform(0, W)  # centre column (width axis)\n",
    "    r_y = np.random.uniform(0, H)  # centre row (height axis)\n",
    "    r_w = W * np.sqrt(1 - lamb)\n",
    "    r_h = H * np.sqrt(1 - lamb)\n",
    "    x1 = int(np.round(max(r_x - r_w / 2, 0)))\n",
    "    x2 = int(np.round(min(r_x + r_w / 2, W)))\n",
    "    y1 = int(np.round(max(r_y - r_h / 2, 0)))\n",
    "    y2 = int(np.round(min(r_y + r_h / 2, H)))\n",
    "\n",
    "    # Dim 2 is height (rows, y) and dim 3 is width (columns, x); indexing\n",
    "    # [x1:x2, y1:y2] would transpose the box on non-square images.\n",
    "    input[:, :, y1:y2, x1:x2] = 0\n",
    "\n",
    "    return input\n",
    "\n",
    "def random_convolution(imgs):\n",
    "    '''\n",
    "    Random convolution augmentation from \"network randomization\".\n",
    "\n",
    "    Passes the batch through a freshly initialised 3x3 convolution (no bias,\n",
    "    padding 1, Xavier-normal weights), so every call uses new random weights\n",
    "    and the output keeps the input's shape.\n",
    "\n",
    "    imgs: B x (C x stack) x H x W torch tensor; it should already be normalized.\n",
    "    '''\n",
    "    num_channels = imgs.shape[1]\n",
    "\n",
    "    # Match the conv to the input's channel count: a hard-coded Conv2d(3, 3)\n",
    "    # raises on 1-channel inputs such as the MNIST batches in this notebook.\n",
    "    rand_conv = nn.Conv2d(num_channels, num_channels, kernel_size=3, bias=False, padding=1).to(imgs.device)\n",
    "    torch.nn.init.xavier_normal_(rand_conv.weight.data)\n",
    "    return rand_conv(imgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "261fda18",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_time(trainloader, model, use_cutout=False, use_random_conv=False):\n",
    "    '''\n",
    "    Time one inference-only pass of `model` over every batch of `trainloader`.\n",
    "\n",
    "    Data loading and the optional augmentations (applied to the CPU batch\n",
    "    before it is moved to the GPU) are included in the measurement. Prints\n",
    "    the elapsed wall-clock time in seconds.\n",
    "\n",
    "    trainloader: iterable of (x, y) batches; y is ignored.\n",
    "    model: expected to live on the CUDA device, since inputs are moved with .cuda().\n",
    "    use_cutout: apply cutout() to each batch.\n",
    "    use_random_conv: apply random_convolution() to each batch.\n",
    "    '''\n",
    "    # CUDA kernels launch asynchronously: synchronize before reading the clock\n",
    "    # at both ends so pending GPU work is charged to the timed interval.\n",
    "    torch.cuda.synchronize()\n",
    "    start = time.perf_counter()\n",
    "    with torch.no_grad():\n",
    "        for x, _ in trainloader:\n",
    "            if use_cutout:\n",
    "                x = cutout(x)\n",
    "            if use_random_conv:\n",
    "                x = random_convolution(x)\n",
    "            model(x.cuda())\n",
    "    torch.cuda.synchronize()\n",
    "    end = time.perf_counter()\n",
    "    print(f'Forward pass on one epoch: {end - start}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eda83da5",
   "metadata": {},
   "source": [
    "### No augmentations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6d70dd53-1270-4281-8178-589ee26062ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline classifier: default MNISTClassifier (no `regul` argument) on 1x32x32 inputs.\n",
    "model = MNISTClassifier(img_size=(1, 32, 32)).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "28cede78",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Forward pass on one epoch: 6.671042203903198\n"
     ]
    }
   ],
   "source": [
    "# Reference run: resize + normalize only, no augmentation.\n",
    "mnist_transform = transforms.Compose([\n",
    "    transforms.Resize((32, 32)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5,), (0.5,)),\n",
    "])\n",
    "\n",
    "trainset = torchvision.datasets.MNIST(\n",
    "    root='./data', train=True, download=True, transform=mnist_transform\n",
    ")\n",
    "trainloader = torch.utils.data.DataLoader(\n",
    "    trainset, batch_size=64, shuffle=True, num_workers=1\n",
    ")\n",
    "compute_time(trainloader, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fb5b529",
   "metadata": {},
   "source": [
    "### Cropping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "93a84f23",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Forward pass on one epoch: 16.983597993850708\n"
     ]
    }
   ],
   "source": [
    "# Random 32x32 crop taken from the image zero-padded by 4 pixels per side, so\n",
    "# the model still receives 1x32x32 inputs. (CenterCrop(224) on a 32x32 image\n",
    "# only zero-pads it to 224x224, timing a 49x larger input instead of a crop.)\n",
    "mnist_transform = transforms.Compose([\n",
    "    transforms.Resize((32, 32)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.RandomCrop(size=32, padding=4),\n",
    "    transforms.Normalize((0.5,), (0.5,)),\n",
    "])\n",
    "\n",
    "trainset = torchvision.datasets.MNIST(\n",
    "    root='./data', train=True, download=True, transform=mnist_transform\n",
    ")\n",
    "trainloader = torch.utils.data.DataLoader(\n",
    "    trainset, batch_size=64, shuffle=True, num_workers=1\n",
    ")\n",
    "compute_time(trainloader, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0724d584",
   "metadata": {},
   "source": [
    "### Rotation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "aab705f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Forward pass on one epoch: 18.11137056350708\n"
     ]
    }
   ],
   "source": [
    "# Random rotation by an angle drawn from [0, 180] degrees, applied per sample.\n",
    "mnist_transform = transforms.Compose([\n",
    "    transforms.Resize((32, 32)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.RandomRotation(degrees=(0, 180)),\n",
    "    transforms.Normalize((0.5,), (0.5,)),\n",
    "])\n",
    "\n",
    "trainset = torchvision.datasets.MNIST(\n",
    "    root='./data', train=True, download=True, transform=mnist_transform\n",
    ")\n",
    "trainloader = torch.utils.data.DataLoader(\n",
    "    trainset, batch_size=64, shuffle=True, num_workers=1\n",
    ")\n",
    "compute_time(trainloader, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71ce4fd3",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Cutout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "bb139ae4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Forward pass on one epoch: 6.935394763946533\n"
     ]
    }
   ],
   "source": [
    "# Cutout is applied batch-wise inside compute_time, so the transform itself\n",
    "# is the same resize + normalize pipeline as the reference run.\n",
    "mnist_transform = transforms.Compose([\n",
    "    transforms.Resize((32, 32)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5,), (0.5,)),\n",
    "])\n",
    "\n",
    "trainset = torchvision.datasets.MNIST(\n",
    "    root='./data', train=True, download=True, transform=mnist_transform\n",
    ")\n",
    "trainloader = torch.utils.data.DataLoader(\n",
    "    trainset, batch_size=64, shuffle=True, num_workers=1\n",
    ")\n",
    "compute_time(trainloader, model, use_cutout=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90b5728e-f0ae-4050-bc08-91b6444ce460",
   "metadata": {},
   "source": [
    "### CLOP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "df72167a-09b2-4e78-ba26-9768f55cbee8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Forward pass on one epoch: 6.640528202056885\n"
     ]
    }
   ],
   "source": [
    "# Use a separate name so the baseline `model` used by the cells above is not\n",
    "# overwritten; re-running an earlier section then still times the baseline.\n",
    "clop_model = MNISTClassifier(img_size=(1, 32, 32), regul='clop', p=0.9).cuda()\n",
    "\n",
    "mnist_transform = transforms.Compose([\n",
    "    transforms.Resize((32, 32)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5,), (0.5,)),\n",
    "])\n",
    "\n",
    "trainset = torchvision.datasets.MNIST(\n",
    "    root='./data', train=True, download=True, transform=mnist_transform\n",
    ")\n",
    "trainloader = torch.utils.data.DataLoader(\n",
    "    trainset, batch_size=64, shuffle=True, num_workers=1\n",
    ")\n",
    "compute_time(trainloader, clop_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2108a67d-675f-459d-9083-4c49240521ba",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
