{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2a6d168c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as T\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "from models import VGG11\n",
    "\n",
    "import numpy as np\n",
    "import time\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3da4edb3-19f0-4be3-821b-cdf097f2e167",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cutout(input):\n",
    "\n",
    "    input_s = torch.zeros_like(input)\n",
    "    lamb = np.random.uniform(0.0, 1.0)\n",
    "\n",
    "    H, W = input.shape[2:]\n",
    "    r_x = np.random.uniform(0, W)\n",
    "    r_y = np.random.uniform(0, H)\n",
    "    r_w = W * np.sqrt(1 - lamb)\n",
    "    r_h = H * np.sqrt(1 - lamb)\n",
    "    x1 = int(np.round(max(r_x - r_w / 2, 0)))\n",
    "    x2 = int(np.round(min(r_x + r_w / 2, W)))\n",
    "    y1 = int(np.round(max(r_y - r_h / 2, 0)))\n",
    "    y2 = int(np.round(min(r_y + r_h / 2, H)))\n",
    "\n",
    "    input[:, :, x1:x2, y1:y2] = input_s[:, :, x1:x2, y1:y2]\n",
    "\n",
    "    return input\n",
    "\n",
    "def random_convolution(imgs):\n",
    "    '''\n",
    "    random covolution in \"network randomization\"\n",
    "    \n",
    "    (imbs): B x (C x stack) x H x W, note: imgs should be normalized and torch tensor\n",
    "    '''\n",
    "    _device = imgs.device\n",
    "    \n",
    "    img_h, img_w = imgs.shape[2], imgs.shape[3]\n",
    "    num_stack_channel = imgs.shape[1]\n",
    "    num_batch = imgs.shape[0]\n",
    "    num_trans = num_batch\n",
    "    batch_size = int(num_batch / num_trans)\n",
    "    \n",
    "    # initialize random covolution\n",
    "    rand_conv = nn.Conv2d(3, 3, kernel_size=3, bias=False, padding=1).to(_device)\n",
    "    torch.nn.init.xavier_normal_(rand_conv.weight.data)\n",
    "    return rand_conv(imgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ca96aead-856f-4640-aa90-7273f09a4a15",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_time(trainloader, model, use_cutout=False, use_random_conv=False):\n",
    "    start = time.time()\n",
    "    for x, y in trainloader:\n",
    "        with torch.no_grad():\n",
    "            if use_cutout:\n",
    "                x = cutout(x)\n",
    "            if use_random_conv:\n",
    "                x = random_convolution(x)\n",
    "            x = x.cuda()\n",
    "            _ = model(x)\n",
    "        pass\n",
    "    end = time.time()\n",
    "    print(f'Forward pass on one epoch: {end - start}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eda83da5",
   "metadata": {},
   "source": [
    "### No augmentations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "409d6888",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = VGG11().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f4146312",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Forward pass on one epoch: 41.66870403289795\n"
     ]
    }
   ],
   "source": [
    "means, stds = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)\n",
    "transform = transforms.Compose([\n",
    "            transforms.Resize((224, 224)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(means, stds),\n",
    "        ])\n",
    "trainset = torchvision.datasets.ImageFolder(os.path.join('data/imagenette2', \"train\"), transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=1)\n",
    "\n",
    "compute_time(trainloader, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fb5b529",
   "metadata": {},
   "source": [
    "### Croping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "93a84f23",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Forward pass on one epoch: 41.71815466880798\n"
     ]
    }
   ],
   "source": [
    "means, stds = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)\n",
    "transform = transforms.Compose([\n",
    "            transforms.Resize((224, 224)),\n",
    "            transforms.ToTensor(),\n",
    "            T.CenterCrop(size=224),\n",
    "            transforms.Normalize(means, stds),\n",
    "        ])\n",
    "trainset = torchvision.datasets.ImageFolder(os.path.join('data/imagenette2', \"train\"), transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=1)\n",
    "\n",
    "compute_time(trainloader, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0724d584",
   "metadata": {},
   "source": [
    "### Rotation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "aab705f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Forward pass on one epoch: 54.179277420043945\n"
     ]
    }
   ],
   "source": [
    "means, stds = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)\n",
    "transform = transforms.Compose([\n",
    "            transforms.Resize((224, 224)),\n",
    "            transforms.ToTensor(),\n",
    "            T.RandomRotation(degrees=(0, 180)),\n",
    "            transforms.Normalize(means, stds),\n",
    "        ])\n",
    "trainset = torchvision.datasets.ImageFolder(os.path.join('data/imagenette2', \"train\"), transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=1)\n",
    "\n",
    "compute_time(trainloader, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71ce4fd3",
   "metadata": {},
   "source": [
    "### Cutout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bb139ae4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Forward pass on one epoch: 41.68955063819885\n"
     ]
    }
   ],
   "source": [
    "means, stds = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)\n",
    "transform = transforms.Compose([\n",
    "            transforms.Resize((224, 224)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(means, stds),\n",
    "        ])\n",
    "trainset = torchvision.datasets.ImageFolder(os.path.join('data/imagenette2', \"train\"), transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=1)\n",
    "\n",
    "compute_time(trainloader, model, use_cutout=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d790ed74-9f05-4252-bf51-b2922a81cd63",
   "metadata": {},
   "source": [
    "### Random conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "574e8320-267c-4063-872a-550e1071f001",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Forward pass on one epoch: 42.38036584854126\n"
     ]
    }
   ],
   "source": [
    "means, stds = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)\n",
    "transform = transforms.Compose([\n",
    "            transforms.Resize((224, 224)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(means, stds),\n",
    "        ])\n",
    "trainset = torchvision.datasets.ImageFolder(os.path.join('data/imagenette2', \"train\"), transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=1)\n",
    "\n",
    "compute_time(trainloader, model, use_random_conv=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90b5728e-f0ae-4050-bc08-91b6444ce460",
   "metadata": {},
   "source": [
    "### CLOP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "88731772-d37a-4997-a9f8-89aca7ee3b91",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Forward pass on one epoch: 41.68545413017273\n"
     ]
    }
   ],
   "source": [
    "model = VGG11(regul='clop', p=0.6).cuda()\n",
    "means, stds = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)\n",
    "transform = transforms.Compose([\n",
    "            transforms.Resize((224, 224)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(means, stds),\n",
    "        ])\n",
    "trainset = torchvision.datasets.ImageFolder(os.path.join('data/imagenette2', \"train\"), transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=1)\n",
    "\n",
    "compute_time(trainloader, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c48c69f-99d6-4d6a-a9a9-50bc9a9dca71",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
