{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f54e6039",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=1\n"
     ]
    }
   ],
   "source": [
    "%env CUDA_VISIBLE_DEVICES=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1ce145d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as t\n",
    "from torch import nn\n",
    "from torchvision import models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "95395586",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.autograd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ab3ed4cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyReLUFn(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, x):\n",
    "        res = nn.functional.relu(x)\n",
    "        ctx.save_for_backward(res)\n",
    "        return res\n",
    "    \n",
    "    @staticmethod\n",
    "    def backward(ctx, grad):\n",
    "        x, = ctx.saved_tensors\n",
    "        return grad * (x > 0)\n",
    "\n",
    "\n",
    "class MyReLU(nn.Module):\n",
    "    def forward(self, x):\n",
    "        return MyReLUFn.apply(x)\n",
    "    \n",
    "    \n",
    "resnet = models.resnet18()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "573c6c12",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ResNet(\n",
       "  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (relu): ReLU(inplace=True)\n",
       "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  (layer1): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (layer2): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (layer3): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (layer4): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "  (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "models.resnet18()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "65bc84bb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "344.58447265625\n",
      "344.58447265625\n",
      "344.58447265625\n"
     ]
    }
   ],
   "source": [
    "BATCH_SZ = 16\n",
    "C = 3\n",
    "H = 224\n",
    "W = 224\n",
    "\n",
    "\n",
    "def run_through(model):\n",
    "    return model(t.randn(BATCH_SZ, C, H, W, requires_grad=True).cuda())\n",
    "\n",
    "\n",
    "def print_mem(offset: int = 0):\n",
    "    print((t.cuda.memory_allocated() - offset) / 2**20)\n",
    "\n",
    "\n",
    "def test_memory(model_getter):\n",
    "    model = model_getter()\n",
    "    \n",
    "    start_mem = t.cuda.memory_allocated()\n",
    "    loss = run_through(model)\n",
    "    print_mem(start_mem)\n",
    "    \n",
    "    \n",
    "def standard_resnet():\n",
    "    return models.resnet18().cuda()\n",
    "\n",
    "    \n",
    "def my_resnet():\n",
    "    resnet = standard_resnet()\n",
    "    \n",
    "    resnet.relu = MyReLU()\n",
    "    for layer in [resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4]:\n",
    "        for block in layer:\n",
    "            block.relu = MyReLU()\n",
    "    \n",
    "    return resnet    \n",
    "    \n",
    "    \n",
    "def noninplace_resnet():\n",
    "    resnet = standard_resnet()\n",
    "    \n",
    "    resnet.relu = nn.ReLU(inplace=False)\n",
    "    for layer in [resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4]:\n",
    "        for block in layer:\n",
    "            block.relu = nn.ReLU(inplace=False)\n",
    "    \n",
    "    return resnet    \n",
    "    \n",
    "test_memory(standard_resnet)\n",
    "test_memory(my_resnet)\n",
    "test_memory(noninplace_resnet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41c452fb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
