{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import resnet50, ResNet50_Weights, swin_v2_b, Swin_V2_B_Weights, resnet101, ResNet101_Weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: \"https://download.pytorch.org/models/resnet101-cd907fc2.pth\" to /home/tarun/.cache/torch/hub/checkpoints/resnet101-cd907fc2.pth\n",
      "100%|██████████| 171M/171M [00:15<00:00, 11.7MB/s] \n"
     ]
    }
   ],
   "source": [
    "# ResNet-101 initialised from the IMAGENET1K_V2 pretrained checkpoint.\n",
    "# The checkpoint is downloaded into the torch hub cache when it is not already there.\n",
    "res_model = resnet101(\n",
    "    weights=ResNet101_Weights.IMAGENET1K_V2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ResNet(\n",
       "  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (relu): ReLU(inplace=True)\n",
       "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  (layer1): Sequential(\n",
       "    (0): Bottleneck(\n",
       "      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): Bottleneck(\n",
       "      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (2): Bottleneck(\n",
       "      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer2): Sequential(\n",
       "    (0): Bottleneck(\n",
       "      (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): Bottleneck(\n",
       "      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (2): Bottleneck(\n",
       "      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (3): Bottleneck(\n",
       "      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer3): Sequential(\n",
       "    (0): Bottleneck(\n",
       "      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (2): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (3): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (4): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (5): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (6): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (7): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (8): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (9): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (10): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (11): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (12): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (13): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (14): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (15): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (16): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (17): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (18): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (19): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (20): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (21): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (22): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer4): Sequential(\n",
       "    (0): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): Bottleneck(\n",
       "      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "    (2): Bottleneck(\n",
       "      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "  (fc): Linear(in_features=2048, out_features=1000, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass an explicit weights enum member rather than the bare Swin_V2_B_Weights class.\n",
    "# Passing the class itself is deprecated since torchvision 0.13 and only worked because\n",
    "# it was resolved to IMAGENET1K_V1, so naming that member loads the same weights\n",
    "# without the deprecation warning.\n",
    "swin_model = swin_v2_b(weights=Swin_V2_B_Weights.IMAGENET1K_V1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SwinTransformer(\n",
       "  (features): Sequential(\n",
       "    (0): Sequential(\n",
       "      (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))\n",
       "      (1): Permute()\n",
       "      (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (1): Sequential(\n",
       "      (0): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=128, out_features=384, bias=True)\n",
       "          (proj): Linear(in_features=128, out_features=128, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=4, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.0, mode=row)\n",
       "        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=128, out_features=512, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=512, out_features=128, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (1): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=128, out_features=384, bias=True)\n",
       "          (proj): Linear(in_features=128, out_features=128, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=4, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.021739130434782608, mode=row)\n",
       "        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=128, out_features=512, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=512, out_features=128, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (2): PatchMergingV2(\n",
       "      (reduction): Linear(in_features=512, out_features=256, bias=False)\n",
       "      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (3): Sequential(\n",
       "      (0): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=256, out_features=768, bias=True)\n",
       "          (proj): Linear(in_features=256, out_features=256, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=8, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.043478260869565216, mode=row)\n",
       "        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=256, out_features=1024, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=1024, out_features=256, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (1): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=256, out_features=768, bias=True)\n",
       "          (proj): Linear(in_features=256, out_features=256, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=8, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.06521739130434782, mode=row)\n",
       "        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=256, out_features=1024, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=1024, out_features=256, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (4): PatchMergingV2(\n",
       "      (reduction): Linear(in_features=1024, out_features=512, bias=False)\n",
       "      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (5): Sequential(\n",
       "      (0): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.08695652173913043, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (1): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.10869565217391304, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (2): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.13043478260869565, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (3): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.15217391304347827, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (4): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.17391304347826086, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (5): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.1956521739130435, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (6): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.21739130434782608, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (7): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.2391304347826087, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (8): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.2608695652173913, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (9): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.2826086956521739, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (10): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.30434782608695654, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (11): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.32608695652173914, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (12): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.34782608695652173, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (13): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.3695652173913043, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (14): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.391304347826087, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (15): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.41304347826086957, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (16): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.43478260869565216, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (17): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "          (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.45652173913043476, mode=row)\n",
       "        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (6): PatchMergingV2(\n",
       "      (reduction): Linear(in_features=2048, out_features=1024, bias=False)\n",
       "      (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (7): Sequential(\n",
       "      (0): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=32, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.4782608695652174, mode=row)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (1): SwinTransformerBlockV2(\n",
       "        (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ShiftedWindowAttentionV2(\n",
       "          (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "          (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "          (cpb_mlp): Sequential(\n",
       "            (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Linear(in_features=512, out_features=32, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (stochastic_depth): StochasticDepth(p=0.5, mode=row)\n",
       "        (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MLP(\n",
       "          (0): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "          (2): Dropout(p=0.0, inplace=False)\n",
       "          (3): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (4): Dropout(p=0.0, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "  (permute): Permute()\n",
       "  (avgpool): AdaptiveAvgPool2d(output_size=1)\n",
       "  (flatten): Flatten(start_dim=1, end_dim=-1)\n",
       "  (head): Linear(in_features=1024, out_features=1000, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "swin_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResNet50Classifier(nn.Module):\n",
    "    \"\"\"ResNet-50 backbone (ImageNet weights) followed by a linear classification head.\n",
    "\n",
    "    Attributes:\n",
    "        repr_model: every child of the pretrained ResNet-50 except its final\n",
    "            ``fc`` layer, followed by ``nn.Flatten()``.\n",
    "        classifier: linear layer mapping the flattened features to ``num_labels`` scores.\n",
    "\n",
    "    Args:\n",
    "        num_labels: number of output classes. When it equals the size of the\n",
    "            pretrained head (1000 for these ImageNet weights) the pretrained\n",
    "            ``fc`` layer is reused; otherwise a new ``nn.Linear`` is created.\n",
    "        *args, **kwargs: forwarded unchanged to ``nn.Module.__init__``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_labels, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)\n",
    "        pretrained_head = backbone.fc\n",
    "        # Drop the last child (the fc head) and flatten what remains so the\n",
    "        # linear classifier below can consume it directly.\n",
    "        self.repr_model = nn.Sequential(*list(backbone.children())[:-1], nn.Flatten())\n",
    "        if num_labels == pretrained_head.out_features:\n",
    "            # Same label count as the checkpoint: keep the pretrained weights\n",
    "            # instead of allocating a layer that would be thrown away.\n",
    "            self.classifier = pretrained_head\n",
    "        else:\n",
    "            # Read the feature width from the checkpoint rather than hard-coding it.\n",
    "            self.classifier = nn.Linear(\n",
    "                in_features=pretrained_head.in_features,\n",
    "                out_features=num_labels,\n",
    "                bias=True,\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run ``x`` through the backbone and return the classifier output.\"\"\"\n",
    "        features = self.repr_model(x)  # named `features` to avoid shadowing builtin repr()\n",
    "        return self.classifier(features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ResNet50Classifier(num_labels=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ResNet50Classifier(\n",
       "  (repr_model): Sequential(\n",
       "    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU(inplace=True)\n",
       "    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "    (4): Sequential(\n",
       "      (0): Bottleneck(\n",
       "        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): Bottleneck(\n",
       "        (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (2): Bottleneck(\n",
       "        (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (5): Sequential(\n",
       "      (0): Bottleneck(\n",
       "        (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): Bottleneck(\n",
       "        (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (2): Bottleneck(\n",
       "        (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (3): Bottleneck(\n",
       "        (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (6): Sequential(\n",
       "      (0): Bottleneck(\n",
       "        (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (2): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (3): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (4): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (5): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (7): Sequential(\n",
       "      (0): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): Bottleneck(\n",
       "        (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "      (2): Bottleneck(\n",
       "        (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (8): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "    (9): Flatten(start_dim=1, end_dim=-1)\n",
       "  )\n",
       "  (classifier): Linear(in_features=2048, out_features=1000, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 3, 224, 224])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: confirm the shape of the dummy tensor fed to the model below\n",
    "# (presumably batch, channels, height, width).\n",
    "torch.randn(1, 3, 224, 224).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dummy input for smoke-testing the forward pass. It is drawn from a dedicated,\n",
    "# seeded generator so that Restart & Run All yields the same tensor (and hence\n",
    "# the same logits) every time, without touching the global RNG state.\n",
    "rand_inp = torch.randn((1, 3, 224, 224), generator=torch.Generator().manual_seed(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-5.8660e-02, -2.1562e-01,  2.3406e-02, -3.3462e-02,  3.9180e-02,\n",
       "         -8.5186e-02, -1.1463e-04,  7.9720e-03,  1.4957e-01, -1.0326e-01,\n",
       "         -2.1049e-01, -1.6315e-02,  5.5162e-02, -1.1860e-01, -6.3085e-02,\n",
       "         -2.7938e-02,  3.4152e-02,  6.3728e-02,  3.6285e-02, -3.0951e-02,\n",
       "         -8.7539e-02,  2.1712e-02,  3.0056e-02,  3.4307e-02,  1.0682e-01,\n",
       "         -1.5340e-01, -1.5469e-01, -1.1974e-01, -2.1853e-01, -5.6144e-02,\n",
       "         -1.3605e-01, -1.5758e-01, -2.1487e-01, -1.0316e-01, -5.9025e-02,\n",
       "          3.3612e-03, -7.9679e-03, -7.6521e-02, -9.4139e-02,  1.3740e-02,\n",
       "         -2.9567e-01, -2.7352e-01, -4.7620e-02, -1.7323e-01, -1.9008e-01,\n",
       "         -2.1079e-01, -3.9792e-02, -2.6766e-01,  1.3814e-02, -1.0559e-01,\n",
       "         -5.4808e-03,  1.2212e-01, -2.3230e-01, -1.2079e-01, -2.3697e-01,\n",
       "         -2.6356e-01, -2.1494e-01, -8.0847e-02,  1.5610e-02, -9.1518e-02,\n",
       "         -1.0659e-01, -1.2309e-01, -7.3370e-02, -1.7430e-01, -3.1835e-02,\n",
       "         -2.3297e-01, -7.5843e-02, -1.7234e-01, -1.7088e-01, -1.0650e-01,\n",
       "         -1.2633e-01, -2.2436e-01, -2.5948e-01,  1.4899e-03, -1.3207e-01,\n",
       "         -1.0716e-01,  5.9147e-02, -3.9363e-02, -6.7278e-03, -1.4659e-01,\n",
       "         -7.3322e-03,  1.1680e-01, -1.1496e-02, -2.6054e-01, -6.1695e-02,\n",
       "          5.2757e-02, -6.4250e-02,  3.9342e-02, -1.2961e-02,  2.4374e-02,\n",
       "         -1.1947e-01,  1.6585e-01, -1.1482e-01,  1.1079e-01,  2.1573e-02,\n",
       "         -9.6015e-02,  1.0226e-01, -9.0010e-02, -1.3473e-02, -7.0562e-02,\n",
       "         -8.0098e-02,  1.4754e-01, -1.7342e-01,  1.7592e-01,  1.4815e-01,\n",
       "          3.9071e-01,  2.5337e-01, -1.3872e-01, -2.2159e-01, -3.1420e-01,\n",
       "         -3.9102e-01, -1.7172e-01,  1.1301e-01, -9.7920e-02, -8.3593e-03,\n",
       "         -2.3097e-01, -2.3011e-01, -2.3369e-01, -1.0459e-01, -1.0186e-02,\n",
       "         -3.5810e-01, -4.5051e-02, -1.7585e-01, -2.1008e-01, -1.8887e-01,\n",
       "         -3.0079e-01,  2.2413e-02,  4.7506e-02, -3.3614e-02,  9.0848e-02,\n",
       "         -2.2051e-01, -4.8445e-03,  1.0666e-01, -5.1658e-02, -1.8597e-02,\n",
       "          6.7795e-02, -1.3586e-01, -1.0246e-01, -6.0281e-02, -1.6838e-01,\n",
       "         -1.7915e-01, -1.1169e-01, -1.4231e-01, -6.1034e-02, -6.7012e-02,\n",
       "         -2.0487e-02, -2.2684e-02,  1.6808e-01,  1.2276e-01,  1.4591e-01,\n",
       "          9.2301e-03,  2.6103e-01, -1.1816e-01,  2.7369e-01,  2.0906e-01,\n",
       "          1.8619e-01,  2.2127e-01,  2.1149e-01, -1.8778e-01,  2.2451e-01,\n",
       "          2.5008e-01,  1.8428e-01,  4.5649e-02,  1.2698e-01,  9.3232e-02,\n",
       "         -2.2601e-01, -1.7165e-02, -2.1441e-01,  3.5009e-01,  3.3781e-01,\n",
       "          4.9500e-01,  1.2219e-01,  1.1443e-01,  4.3969e-01,  3.4117e-01,\n",
       "          8.9539e-02, -1.3635e-03,  6.5815e-02,  2.9517e-01,  8.6058e-02,\n",
       "          2.1255e-01,  2.8019e-01,  2.9538e-01,  8.0565e-02,  5.3600e-01,\n",
       "          2.4076e-01,  3.7005e-01,  1.1593e-01,  2.8775e-01,  2.0053e-01,\n",
       "          2.6629e-01,  2.9318e-01,  3.2647e-01, -1.2192e-01,  2.2759e-01,\n",
       "         -3.4552e-02,  6.5858e-02,  1.2184e-01, -1.1611e-01,  1.7002e-01,\n",
       "          2.2899e-01,  1.5947e-01,  4.9881e-01,  6.2806e-01,  3.1204e-01,\n",
       "         -8.7780e-02,  1.2163e-01,  4.1696e-01,  3.6783e-01,  3.6908e-01,\n",
       "         -5.8080e-03,  3.2986e-01,  1.1076e-01,  1.7374e-01,  4.6612e-02,\n",
       "          1.6096e-01,  4.9993e-01,  6.4068e-03,  3.6882e-01,  1.7274e-01,\n",
       "          3.5628e-02, -7.3599e-03,  7.3059e-01,  1.9854e-01,  5.5261e-02,\n",
       "          1.1354e-01,  9.7396e-02, -2.4459e-01,  4.8655e-01,  6.3419e-01,\n",
       "          1.7407e-01,  1.3035e-01, -8.0233e-02,  2.0570e-01,  2.0051e-01,\n",
       "          1.3842e-01,  1.2668e-01,  2.8960e-01, -4.9516e-02,  9.5622e-02,\n",
       "          3.9276e-02,  8.1422e-02,  1.9139e-01,  3.2768e-01,  1.7151e-01,\n",
       "          2.2235e-01,  1.7143e-01,  3.3260e-01,  2.4008e-01,  4.3723e-01,\n",
       "          3.6052e-01, -1.4622e-02, -3.2805e-01,  2.0522e-01,  2.7127e-01,\n",
       "          6.7445e-02, -6.2343e-02,  4.6008e-01,  4.7580e-01,  3.9206e-01,\n",
       "          3.7176e-01,  3.8617e-01,  3.2434e-02,  3.5646e-01,  9.5142e-02,\n",
       "          2.9630e-01,  2.8891e-01,  3.1543e-01, -1.8976e-01,  4.0708e-01,\n",
       "          5.2407e-01,  3.4148e-01,  2.8393e-01,  3.4582e-01,  3.4763e-01,\n",
       "          1.0872e-01,  2.4034e-01,  4.6310e-01,  1.9363e-01,  6.1256e-01,\n",
       "          2.0091e-01,  3.7698e-01,  4.2466e-01,  3.3629e-01,  5.3141e-01,\n",
       "          3.2738e-01,  2.8749e-01,  4.1361e-01, -1.1674e-01,  1.4269e-01,\n",
       "         -2.1067e-01,  2.2781e-01, -2.2370e-02, -1.2583e-01,  3.7178e-01,\n",
       "          2.8018e-01,  7.3633e-01,  2.1212e-01,  3.3912e-01,  3.3061e-01,\n",
       "         -3.6240e-01, -2.4994e-01, -2.1300e-01, -1.8365e-01, -3.3815e-01,\n",
       "         -1.3740e-01,  8.9831e-02, -2.7666e-01, -2.8130e-01, -1.7547e-01,\n",
       "         -1.7428e-01, -2.8445e-01, -2.8936e-01, -1.6512e-01, -1.0570e-01,\n",
       "         -1.5370e-01,  1.1436e-03, -2.3495e-01, -2.5799e-01, -2.0066e-01,\n",
       "         -3.2411e-01, -2.0960e-01, -2.8850e-01, -4.0092e-01, -7.4763e-02,\n",
       "         -2.8938e-01, -3.2552e-01, -1.1810e-01, -4.2431e-02, -2.0727e-01,\n",
       "          5.6929e-02,  1.0431e-01,  1.3905e-01,  1.8015e-01,  1.8536e-01,\n",
       "          3.3983e-01,  2.8186e-01,  1.9269e-01,  3.2566e-01,  3.4010e-02,\n",
       "         -2.3496e-01,  1.5395e-01,  8.7668e-02, -5.2797e-02,  6.4299e-02,\n",
       "          5.0769e-01,  4.1322e-01,  7.1562e-02,  6.0935e-02,  7.6061e-02,\n",
       "          1.0319e-01,  2.1294e-01,  1.5302e-01,  1.2171e-01,  1.4456e-01,\n",
       "          3.5053e-02,  6.6327e-01,  3.4303e-01,  5.8247e-01,  7.0268e-01,\n",
       "          1.7966e-01,  1.2798e-01,  2.7284e-01, -4.3191e-02,  4.7465e-01,\n",
       "          1.5055e-01,  1.6825e-01,  1.4819e-01,  3.9967e-01,  9.4724e-04,\n",
       "          1.7639e-01,  3.2570e-01,  2.5935e-01,  3.6719e-01,  5.1155e-01,\n",
       "          2.2515e-01,  4.5560e-01,  2.2075e-01,  2.5566e-01,  3.9653e-01,\n",
       "          3.6347e-01,  2.5023e-01,  3.5295e-01,  3.9626e-01,  1.8832e-01,\n",
       "          9.1488e-02, -1.4247e-02,  3.3458e-01,  2.3136e-01, -1.5594e-01,\n",
       "         -4.1402e-01, -3.6743e-02, -3.8043e-01, -2.5838e-01,  6.6641e-02,\n",
       "         -7.5254e-02, -2.9376e-01, -2.4616e-01,  4.1705e-03, -1.1467e-01,\n",
       "         -5.8850e-02, -8.5694e-02, -7.5935e-02,  7.8767e-02, -2.0621e-01,\n",
       "          1.7142e-01,  2.1569e-01, -4.0482e-01, -1.1024e-01, -2.6733e-01,\n",
       "          1.4650e-01,  1.2285e-01,  9.9291e-02,  1.0571e-01, -2.8415e-02,\n",
       "          1.7755e-01, -1.6525e-01,  4.2671e-01, -1.1300e-01,  1.6585e-01,\n",
       "          1.2276e-01, -1.5552e-01, -3.3868e-02, -1.2518e-01, -1.0450e-01,\n",
       "          3.0160e-01, -2.7929e-01,  1.2376e-01,  1.9477e-01,  3.9703e-02,\n",
       "         -1.1759e-01, -3.7576e-02, -3.8561e-02, -2.2450e-01,  9.0587e-03,\n",
       "          4.0039e-01, -1.7462e-01,  1.2679e-01,  1.2052e-01,  1.4503e-02,\n",
       "          7.2368e-02, -2.2911e-02, -1.0229e-01, -2.4184e-01, -2.6660e-01,\n",
       "          6.9126e-02, -7.1969e-02, -2.1321e-02, -1.0584e-02, -8.0532e-02,\n",
       "         -1.1274e-01, -1.9391e-01, -8.1003e-02,  6.5148e-02,  7.0510e-02,\n",
       "          8.1063e-02, -7.5343e-02,  8.1484e-02, -2.6919e-01,  2.5432e-01,\n",
       "          4.4258e-02,  1.0812e-01, -4.5133e-02,  3.3619e-01, -2.2552e-01,\n",
       "         -2.1073e-01, -3.3512e-01, -1.1773e-02, -1.2269e-01,  1.7527e-01,\n",
       "          3.6220e-02, -7.2015e-03,  2.3288e-01, -1.2113e-01,  7.8424e-02,\n",
       "         -1.5986e-02, -3.4088e-01, -2.9910e-01,  2.7640e-02, -2.0679e-01,\n",
       "          1.2191e-02,  1.1560e-02, -1.9736e-01,  2.6542e-01, -1.3175e-01,\n",
       "         -9.1638e-02, -5.5544e-02, -3.2728e-02, -4.5655e-02, -9.6354e-02,\n",
       "         -7.1666e-02,  4.8999e-02, -2.6076e-02,  8.6134e-02, -1.2266e-01,\n",
       "          6.2707e-02, -3.1741e-01, -1.1749e-02, -1.1295e-02,  3.6409e-02,\n",
       "          2.3761e-01, -1.2505e-01, -8.9481e-02, -1.9531e-01, -1.8761e-01,\n",
       "         -8.0336e-03, -2.5456e-01, -2.6382e-01, -8.7645e-03, -1.0684e-01,\n",
       "          2.5339e-02, -2.0595e-01, -3.2660e-01, -3.8845e-02, -5.2941e-02,\n",
       "          4.7773e-02,  3.6050e-02, -2.9667e-01, -9.9204e-02,  2.1376e-01,\n",
       "         -6.2701e-02, -1.6944e-01, -5.1408e-02,  2.3885e-02, -7.6945e-02,\n",
       "          1.2495e-01,  1.4759e-01,  6.1426e-02, -7.2193e-02, -2.1667e-01,\n",
       "         -1.7329e-02, -3.6878e-01, -7.2065e-03, -2.8384e-01,  3.8618e-01,\n",
       "         -3.7536e-01, -1.1039e-01,  2.1628e-02, -8.5171e-02,  3.0927e-01,\n",
       "         -2.7605e-02, -9.6775e-02,  1.6450e-01, -2.5837e-01, -2.3381e-01,\n",
       "         -5.1041e-02, -3.1330e-01, -3.5100e-02,  2.6014e-01, -2.8547e-02,\n",
       "          4.0444e-02, -1.3048e-01, -4.3301e-02,  1.9448e-01,  1.9649e-02,\n",
       "         -2.8926e-01, -5.9976e-02, -1.3830e-01,  1.5823e-01,  1.2106e-01,\n",
       "         -2.6639e-03, -2.6532e-01,  1.6724e-01, -1.5998e-01,  4.8005e-01,\n",
       "         -1.6915e-01, -3.1967e-01, -6.6139e-02,  2.3338e-01, -3.9123e-02,\n",
       "          6.2504e-02, -1.5678e-01,  3.7560e-02, -3.7048e-01,  4.5680e-02,\n",
       "         -1.7966e-01, -1.6408e-01,  3.6574e-02, -1.8910e-01,  1.2427e-02,\n",
       "          9.3017e-02, -2.2461e-01,  1.5170e-01,  4.3417e-02, -4.6470e-02,\n",
       "         -1.5336e-01, -1.8671e-01,  1.2724e-01, -6.9469e-03,  1.3931e-01,\n",
       "         -2.9443e-01, -3.1339e-01, -1.6544e-01,  1.9363e-01, -1.6315e-01,\n",
       "         -1.1992e-01,  5.4212e-03, -7.6263e-02,  4.2482e-01, -2.8010e-01,\n",
       "         -1.0298e-01,  1.0719e-01, -2.1474e-01, -2.9642e-01,  1.9715e-01,\n",
       "         -3.2183e-02,  4.1459e-02, -3.9711e-02, -1.1227e-01, -3.1510e-02,\n",
       "         -5.2222e-03, -1.2478e-01, -2.6623e-01, -8.5785e-02, -2.6140e-01,\n",
       "         -2.8393e-01, -5.3811e-02, -1.6795e-01,  7.7608e-02, -9.4560e-02,\n",
       "          1.8537e-01, -1.3381e-01,  1.1088e-01, -1.0408e-01, -1.3889e-01,\n",
       "         -5.9964e-03, -1.5640e-01, -1.8662e-01, -2.1673e-02,  4.8620e-02,\n",
       "         -1.8421e-01,  2.7766e-01, -1.6006e-01, -8.4446e-02, -8.3540e-02,\n",
       "         -3.0732e-01,  1.9243e-01, -8.9779e-02,  4.1132e-02, -3.4977e-02,\n",
       "         -3.3654e-01, -1.1833e-01, -5.6749e-02, -1.0998e-01, -7.1620e-02,\n",
       "         -3.3300e-01,  6.9617e-02, -6.5482e-03,  3.0713e-01,  1.4317e-01,\n",
       "          2.5623e-01,  2.1503e-01, -1.2110e-01,  2.0372e-01, -3.1207e-01,\n",
       "          1.6217e-01, -2.1337e-01, -4.7750e-02, -1.7767e-01,  8.9443e-02,\n",
       "          4.3317e-03, -3.0448e-01, -3.8125e-02,  1.5997e-01,  1.5064e-01,\n",
       "         -2.7687e-01,  5.7003e-02, -1.8123e-02,  3.5218e-02,  4.0356e-01,\n",
       "         -3.3884e-01,  1.1523e-01,  1.4304e-01,  2.0349e-01,  1.4582e-01,\n",
       "         -2.2050e-01,  8.6250e-02, -6.1388e-02, -1.0967e-02,  2.5518e-02,\n",
       "         -2.6198e-02,  1.5040e-01, -1.2555e-01, -2.5882e-01, -1.5671e-01,\n",
       "         -3.1024e-01, -3.4687e-01, -2.1213e-01, -1.0582e-01, -5.4677e-02,\n",
       "          1.2054e-01,  1.5666e-01, -2.8328e-01, -3.4375e-02, -2.3331e-01,\n",
       "         -1.0096e-01, -1.5373e-02, -2.1569e-01,  1.5409e-01,  1.8641e-02,\n",
       "          1.7794e-01,  3.8752e-02, -7.4389e-02, -1.0127e-01, -8.4036e-02,\n",
       "         -1.8661e-01,  4.5518e-02, -2.3550e-01,  5.5479e-02, -3.1183e-01,\n",
       "         -2.0483e-01,  8.4746e-02, -2.9469e-01,  2.3071e-01, -8.6013e-02,\n",
       "         -4.2235e-02,  1.5816e-01, -3.2402e-01, -1.8819e-01, -1.8873e-01,\n",
       "         -4.4247e-02, -2.7022e-01,  2.3252e-01, -1.0359e-01, -2.2290e-01,\n",
       "          9.7211e-02,  7.8672e-02, -1.4051e-01, -3.1084e-02, -1.8090e-01,\n",
       "         -1.7161e-01,  2.3969e-02,  9.9099e-03, -7.8489e-02, -3.6927e-01,\n",
       "          5.8634e-02,  3.3587e-01, -1.6323e-02,  4.2122e-01,  7.0401e-02,\n",
       "         -2.5359e-01, -9.4064e-02,  1.5538e-01,  1.6284e-03, -1.9452e-01,\n",
       "          1.4135e-01,  6.7474e-02,  4.2442e-02, -6.8525e-02,  7.9882e-02,\n",
       "          1.1528e-01, -2.4441e-01, -1.3187e-01, -3.4780e-03, -4.2392e-02,\n",
       "          2.6853e-02,  1.4400e-01, -9.9624e-02, -2.6294e-01, -1.1432e-01,\n",
       "          2.3365e-01, -7.1575e-02,  4.7919e-02,  1.3444e-02, -7.8988e-02,\n",
       "         -2.6020e-01,  1.8627e-01, -1.6766e-01, -2.4382e-01, -2.3937e-03,\n",
       "         -1.1106e-01, -1.1733e-01, -2.1904e-03, -6.1306e-02, -6.8139e-02,\n",
       "         -6.0180e-02, -2.7575e-02, -1.6927e-01,  2.2123e-02,  5.1046e-03,\n",
       "          9.1977e-02, -4.2010e-01,  3.1531e-02, -1.0432e-01, -5.9499e-02,\n",
       "         -1.4227e-01, -2.2071e-03, -1.0201e-01, -2.0395e-01, -6.2591e-02,\n",
       "         -9.9255e-02, -5.9713e-02,  6.9466e-03,  3.9616e-02,  2.5549e-01,\n",
       "         -1.3089e-01,  1.3629e-01,  1.9560e-02, -1.3970e-01,  7.6565e-02,\n",
       "         -2.5834e-01, -1.4425e-01, -4.1600e-02,  6.4720e-02, -4.8293e-02,\n",
       "          3.9835e-02, -1.0735e-01, -1.9065e-01, -1.0297e-02, -4.6257e-01,\n",
       "         -1.1275e-01,  1.3461e-01,  2.3439e-02,  1.3629e-01, -1.1407e-01,\n",
       "         -5.8173e-02, -1.8470e-01, -1.0108e-01, -2.1546e-01,  8.1609e-02,\n",
       "         -1.1135e-01, -1.9366e-01, -2.5534e-02,  7.1518e-03,  8.3766e-02,\n",
       "          2.0059e-01, -2.4005e-01,  2.7818e-01,  1.1070e-01, -2.9542e-02,\n",
       "         -2.2901e-01,  2.2883e-01, -3.4692e-02,  5.7272e-02, -9.1160e-02,\n",
       "         -2.5467e-01,  6.2668e-03, -1.0490e-01, -9.0321e-02,  4.8394e-02,\n",
       "          8.3008e-02, -2.6141e-02,  6.4451e-02, -1.5710e-01, -8.5294e-02,\n",
       "         -7.5820e-02,  3.4173e-02, -3.7875e-02, -1.0733e-01,  9.8452e-02,\n",
       "          1.7418e-01,  3.8679e-01,  1.6042e-01, -1.2213e-01,  3.1008e-01,\n",
       "         -2.8292e-01, -2.6848e-01, -2.1315e-01,  8.9100e-02,  5.7202e-02,\n",
       "         -2.5249e-01, -3.5481e-02, -1.0943e-01, -1.4575e-01, -2.3862e-01,\n",
       "         -4.1532e-01, -1.7416e-01, -5.9082e-02,  6.0813e-02, -1.8539e-01,\n",
       "         -2.2777e-01, -1.4130e-01,  1.1217e-01, -9.2016e-02, -1.5192e-01,\n",
       "         -1.1104e-01,  4.2307e-01, -1.9202e-01, -2.6294e-01, -6.2639e-02,\n",
       "          7.3836e-03, -2.8055e-02, -3.0666e-02,  6.3371e-02,  3.2942e-03,\n",
       "         -3.9766e-02, -2.2645e-01, -1.0134e-01,  1.4264e-01, -3.6185e-01,\n",
       "         -1.7681e-01,  2.0689e-02, -6.9407e-02, -9.9117e-02,  1.5134e-01,\n",
       "         -1.6631e-01,  1.7213e-01,  1.5422e-01,  2.4922e-02,  1.6776e-01,\n",
       "         -2.8483e-02,  2.0558e-01, -6.2833e-02,  1.3241e-01,  1.6673e-01,\n",
       "          3.5344e-01, -1.4939e-01, -1.3477e-02,  4.9669e-02, -2.1777e-01,\n",
       "          1.1248e-01, -1.0244e-01,  4.4273e-02, -5.9534e-03, -1.0321e-01,\n",
       "          1.0159e-01, -1.7996e-01, -1.0837e-01, -1.0760e-01, -7.9204e-02,\n",
       "         -7.6940e-02,  1.8530e-01, -3.8072e-02, -3.2123e-01, -3.1962e-01,\n",
       "         -3.5619e-01, -2.7386e-01, -2.4828e-01, -2.1376e-01, -6.6686e-02,\n",
       "         -1.7495e-02, -2.3398e-01, -1.5790e-01, -3.8745e-01, -3.0398e-01,\n",
       "         -2.5111e-01,  2.1225e-02, -2.1580e-01, -1.6745e-01, -1.3299e-01,\n",
       "         -1.1569e-01, -1.4653e-01, -1.0050e-01, -1.0652e-01, -1.9919e-02,\n",
       "          4.3570e-03, -5.5960e-02,  6.1171e-02, -9.7138e-02,  1.0088e-02,\n",
       "         -1.7262e-01, -1.0604e-01, -1.0942e-01, -1.0560e-01, -8.3540e-02,\n",
       "         -1.3718e-01,  1.7382e-01, -1.2698e-01,  1.4962e-01, -3.2334e-01,\n",
       "         -4.1813e-02,  1.0611e-01, -2.8023e-01, -3.7135e-01, -2.9572e-01,\n",
       "         -3.0216e-01,  1.4628e-01, -1.2960e-01, -1.5928e-01,  7.4739e-02,\n",
       "          2.4275e-01,  5.4764e-02,  3.4404e-01, -7.5884e-02,  2.0467e-01,\n",
       "          1.6417e-01,  1.9391e-01,  2.8781e-01,  1.9518e-01,  3.3620e-01,\n",
       "          3.1126e-01, -2.1227e-01, -3.3151e-02, -1.8200e-01,  6.6010e-02,\n",
       "         -1.3311e-01, -1.4060e-01, -1.8008e-01,  1.1491e-01, -2.2763e-02,\n",
       "          1.4658e-02, -1.2218e-01, -1.3545e-01,  1.0391e-01, -5.5223e-02,\n",
       "          9.8317e-02,  2.7491e-02, -8.7586e-02, -1.7342e-01, -5.4462e-02]],\n",
       "       grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inference-only smoke test: disable autograd so no graph is recorded for this\n",
    "# throwaway forward pass (a tracked call returns a tensor carrying a grad_fn).\n",
    "with torch.no_grad():\n",
    "    rand_logits = model(rand_inp)\n",
    "rand_logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-5.8660e-02, -2.1562e-01,  2.3406e-02, -3.3462e-02,  3.9180e-02,\n",
       "         -8.5186e-02, -1.1463e-04,  7.9720e-03,  1.4957e-01, -1.0326e-01,\n",
       "         -2.1049e-01, -1.6315e-02,  5.5162e-02, -1.1860e-01, -6.3085e-02,\n",
       "         -2.7938e-02,  3.4152e-02,  6.3728e-02,  3.6285e-02, -3.0951e-02,\n",
       "         -8.7539e-02,  2.1712e-02,  3.0056e-02,  3.4307e-02,  1.0682e-01,\n",
       "         -1.5340e-01, -1.5469e-01, -1.1974e-01, -2.1853e-01, -5.6144e-02,\n",
       "         -1.3605e-01, -1.5758e-01, -2.1487e-01, -1.0316e-01, -5.9025e-02,\n",
       "          3.3612e-03, -7.9679e-03, -7.6521e-02, -9.4139e-02,  1.3740e-02,\n",
       "         -2.9567e-01, -2.7352e-01, -4.7620e-02, -1.7323e-01, -1.9008e-01,\n",
       "         -2.1079e-01, -3.9792e-02, -2.6766e-01,  1.3814e-02, -1.0559e-01,\n",
       "         -5.4808e-03,  1.2212e-01, -2.3230e-01, -1.2079e-01, -2.3697e-01,\n",
       "         -2.6356e-01, -2.1494e-01, -8.0847e-02,  1.5610e-02, -9.1518e-02,\n",
       "         -1.0659e-01, -1.2309e-01, -7.3370e-02, -1.7430e-01, -3.1835e-02,\n",
       "         -2.3297e-01, -7.5843e-02, -1.7234e-01, -1.7088e-01, -1.0650e-01,\n",
       "         -1.2633e-01, -2.2436e-01, -2.5948e-01,  1.4899e-03, -1.3207e-01,\n",
       "         -1.0716e-01,  5.9147e-02, -3.9363e-02, -6.7278e-03, -1.4659e-01,\n",
       "         -7.3322e-03,  1.1680e-01, -1.1496e-02, -2.6054e-01, -6.1695e-02,\n",
       "          5.2757e-02, -6.4250e-02,  3.9342e-02, -1.2961e-02,  2.4374e-02,\n",
       "         -1.1947e-01,  1.6585e-01, -1.1482e-01,  1.1079e-01,  2.1573e-02,\n",
       "         -9.6015e-02,  1.0226e-01, -9.0010e-02, -1.3473e-02, -7.0562e-02,\n",
       "         -8.0098e-02,  1.4754e-01, -1.7342e-01,  1.7592e-01,  1.4815e-01,\n",
       "          3.9071e-01,  2.5337e-01, -1.3872e-01, -2.2159e-01, -3.1420e-01,\n",
       "         -3.9102e-01, -1.7172e-01,  1.1301e-01, -9.7920e-02, -8.3593e-03,\n",
       "         -2.3097e-01, -2.3011e-01, -2.3369e-01, -1.0459e-01, -1.0186e-02,\n",
       "         -3.5810e-01, -4.5051e-02, -1.7585e-01, -2.1008e-01, -1.8887e-01,\n",
       "         -3.0079e-01,  2.2413e-02,  4.7506e-02, -3.3614e-02,  9.0848e-02,\n",
       "         -2.2051e-01, -4.8445e-03,  1.0666e-01, -5.1658e-02, -1.8597e-02,\n",
       "          6.7795e-02, -1.3586e-01, -1.0246e-01, -6.0281e-02, -1.6838e-01,\n",
       "         -1.7915e-01, -1.1169e-01, -1.4231e-01, -6.1034e-02, -6.7012e-02,\n",
       "         -2.0487e-02, -2.2684e-02,  1.6808e-01,  1.2276e-01,  1.4591e-01,\n",
       "          9.2301e-03,  2.6103e-01, -1.1816e-01,  2.7369e-01,  2.0906e-01,\n",
       "          1.8619e-01,  2.2127e-01,  2.1149e-01, -1.8778e-01,  2.2451e-01,\n",
       "          2.5008e-01,  1.8428e-01,  4.5649e-02,  1.2698e-01,  9.3232e-02,\n",
       "         -2.2601e-01, -1.7165e-02, -2.1441e-01,  3.5009e-01,  3.3781e-01,\n",
       "          4.9500e-01,  1.2219e-01,  1.1443e-01,  4.3969e-01,  3.4117e-01,\n",
       "          8.9539e-02, -1.3635e-03,  6.5815e-02,  2.9517e-01,  8.6058e-02,\n",
       "          2.1255e-01,  2.8019e-01,  2.9538e-01,  8.0565e-02,  5.3600e-01,\n",
       "          2.4076e-01,  3.7005e-01,  1.1593e-01,  2.8775e-01,  2.0053e-01,\n",
       "          2.6629e-01,  2.9318e-01,  3.2647e-01, -1.2192e-01,  2.2759e-01,\n",
       "         -3.4552e-02,  6.5858e-02,  1.2184e-01, -1.1611e-01,  1.7002e-01,\n",
       "          2.2899e-01,  1.5947e-01,  4.9881e-01,  6.2806e-01,  3.1204e-01,\n",
       "         -8.7780e-02,  1.2163e-01,  4.1696e-01,  3.6783e-01,  3.6908e-01,\n",
       "         -5.8080e-03,  3.2986e-01,  1.1076e-01,  1.7374e-01,  4.6612e-02,\n",
       "          1.6096e-01,  4.9993e-01,  6.4068e-03,  3.6882e-01,  1.7274e-01,\n",
       "          3.5628e-02, -7.3599e-03,  7.3059e-01,  1.9854e-01,  5.5261e-02,\n",
       "          1.1354e-01,  9.7396e-02, -2.4459e-01,  4.8655e-01,  6.3419e-01,\n",
       "          1.7407e-01,  1.3035e-01, -8.0233e-02,  2.0570e-01,  2.0051e-01,\n",
       "          1.3842e-01,  1.2668e-01,  2.8960e-01, -4.9516e-02,  9.5622e-02,\n",
       "          3.9276e-02,  8.1422e-02,  1.9139e-01,  3.2768e-01,  1.7151e-01,\n",
       "          2.2235e-01,  1.7143e-01,  3.3260e-01,  2.4008e-01,  4.3723e-01,\n",
       "          3.6052e-01, -1.4622e-02, -3.2805e-01,  2.0522e-01,  2.7127e-01,\n",
       "          6.7445e-02, -6.2343e-02,  4.6008e-01,  4.7580e-01,  3.9206e-01,\n",
       "          3.7176e-01,  3.8617e-01,  3.2434e-02,  3.5646e-01,  9.5142e-02,\n",
       "          2.9630e-01,  2.8891e-01,  3.1543e-01, -1.8976e-01,  4.0708e-01,\n",
       "          5.2407e-01,  3.4148e-01,  2.8393e-01,  3.4582e-01,  3.4763e-01,\n",
       "          1.0872e-01,  2.4034e-01,  4.6310e-01,  1.9363e-01,  6.1256e-01,\n",
       "          2.0091e-01,  3.7698e-01,  4.2466e-01,  3.3629e-01,  5.3141e-01,\n",
       "          3.2738e-01,  2.8749e-01,  4.1361e-01, -1.1674e-01,  1.4269e-01,\n",
       "         -2.1067e-01,  2.2781e-01, -2.2370e-02, -1.2583e-01,  3.7178e-01,\n",
       "          2.8018e-01,  7.3633e-01,  2.1212e-01,  3.3912e-01,  3.3061e-01,\n",
       "         -3.6240e-01, -2.4994e-01, -2.1300e-01, -1.8365e-01, -3.3815e-01,\n",
       "         -1.3740e-01,  8.9831e-02, -2.7666e-01, -2.8130e-01, -1.7547e-01,\n",
       "         -1.7428e-01, -2.8445e-01, -2.8936e-01, -1.6512e-01, -1.0570e-01,\n",
       "         -1.5370e-01,  1.1436e-03, -2.3495e-01, -2.5799e-01, -2.0066e-01,\n",
       "         -3.2411e-01, -2.0960e-01, -2.8850e-01, -4.0092e-01, -7.4763e-02,\n",
       "         -2.8938e-01, -3.2552e-01, -1.1810e-01, -4.2431e-02, -2.0727e-01,\n",
       "          5.6929e-02,  1.0431e-01,  1.3905e-01,  1.8015e-01,  1.8536e-01,\n",
       "          3.3983e-01,  2.8186e-01,  1.9269e-01,  3.2566e-01,  3.4010e-02,\n",
       "         -2.3496e-01,  1.5395e-01,  8.7668e-02, -5.2797e-02,  6.4299e-02,\n",
       "          5.0769e-01,  4.1322e-01,  7.1562e-02,  6.0935e-02,  7.6061e-02,\n",
       "          1.0319e-01,  2.1294e-01,  1.5302e-01,  1.2171e-01,  1.4456e-01,\n",
       "          3.5053e-02,  6.6327e-01,  3.4303e-01,  5.8247e-01,  7.0268e-01,\n",
       "          1.7966e-01,  1.2798e-01,  2.7284e-01, -4.3191e-02,  4.7465e-01,\n",
       "          1.5055e-01,  1.6825e-01,  1.4819e-01,  3.9967e-01,  9.4724e-04,\n",
       "          1.7639e-01,  3.2570e-01,  2.5935e-01,  3.6719e-01,  5.1155e-01,\n",
       "          2.2515e-01,  4.5560e-01,  2.2075e-01,  2.5566e-01,  3.9653e-01,\n",
       "          3.6347e-01,  2.5023e-01,  3.5295e-01,  3.9626e-01,  1.8832e-01,\n",
       "          9.1488e-02, -1.4247e-02,  3.3458e-01,  2.3136e-01, -1.5594e-01,\n",
       "         -4.1402e-01, -3.6743e-02, -3.8043e-01, -2.5838e-01,  6.6641e-02,\n",
       "         -7.5254e-02, -2.9376e-01, -2.4616e-01,  4.1705e-03, -1.1467e-01,\n",
       "         -5.8850e-02, -8.5694e-02, -7.5935e-02,  7.8767e-02, -2.0621e-01,\n",
       "          1.7142e-01,  2.1569e-01, -4.0482e-01, -1.1024e-01, -2.6733e-01,\n",
       "          1.4650e-01,  1.2285e-01,  9.9291e-02,  1.0571e-01, -2.8415e-02,\n",
       "          1.7755e-01, -1.6525e-01,  4.2671e-01, -1.1300e-01,  1.6585e-01,\n",
       "          1.2276e-01, -1.5552e-01, -3.3868e-02, -1.2518e-01, -1.0450e-01,\n",
       "          3.0160e-01, -2.7929e-01,  1.2376e-01,  1.9477e-01,  3.9703e-02,\n",
       "         -1.1759e-01, -3.7576e-02, -3.8561e-02, -2.2450e-01,  9.0587e-03,\n",
       "          4.0039e-01, -1.7462e-01,  1.2679e-01,  1.2052e-01,  1.4503e-02,\n",
       "          7.2368e-02, -2.2911e-02, -1.0229e-01, -2.4184e-01, -2.6660e-01,\n",
       "          6.9126e-02, -7.1969e-02, -2.1321e-02, -1.0584e-02, -8.0532e-02,\n",
       "         -1.1274e-01, -1.9391e-01, -8.1003e-02,  6.5148e-02,  7.0510e-02,\n",
       "          8.1063e-02, -7.5343e-02,  8.1484e-02, -2.6919e-01,  2.5432e-01,\n",
       "          4.4258e-02,  1.0812e-01, -4.5133e-02,  3.3619e-01, -2.2552e-01,\n",
       "         -2.1073e-01, -3.3512e-01, -1.1773e-02, -1.2269e-01,  1.7527e-01,\n",
       "          3.6220e-02, -7.2015e-03,  2.3288e-01, -1.2113e-01,  7.8424e-02,\n",
       "         -1.5986e-02, -3.4088e-01, -2.9910e-01,  2.7640e-02, -2.0679e-01,\n",
       "          1.2191e-02,  1.1560e-02, -1.9736e-01,  2.6542e-01, -1.3175e-01,\n",
       "         -9.1638e-02, -5.5544e-02, -3.2728e-02, -4.5655e-02, -9.6354e-02,\n",
       "         -7.1666e-02,  4.8999e-02, -2.6076e-02,  8.6134e-02, -1.2266e-01,\n",
       "          6.2707e-02, -3.1741e-01, -1.1749e-02, -1.1295e-02,  3.6409e-02,\n",
       "          2.3761e-01, -1.2505e-01, -8.9481e-02, -1.9531e-01, -1.8761e-01,\n",
       "         -8.0336e-03, -2.5456e-01, -2.6382e-01, -8.7645e-03, -1.0684e-01,\n",
       "          2.5339e-02, -2.0595e-01, -3.2660e-01, -3.8845e-02, -5.2941e-02,\n",
       "          4.7773e-02,  3.6050e-02, -2.9667e-01, -9.9204e-02,  2.1376e-01,\n",
       "         -6.2701e-02, -1.6944e-01, -5.1408e-02,  2.3885e-02, -7.6945e-02,\n",
       "          1.2495e-01,  1.4759e-01,  6.1426e-02, -7.2193e-02, -2.1667e-01,\n",
       "         -1.7329e-02, -3.6878e-01, -7.2065e-03, -2.8384e-01,  3.8618e-01,\n",
       "         -3.7536e-01, -1.1039e-01,  2.1628e-02, -8.5171e-02,  3.0927e-01,\n",
       "         -2.7605e-02, -9.6775e-02,  1.6450e-01, -2.5837e-01, -2.3381e-01,\n",
       "         -5.1041e-02, -3.1330e-01, -3.5100e-02,  2.6014e-01, -2.8547e-02,\n",
       "          4.0444e-02, -1.3048e-01, -4.3301e-02,  1.9448e-01,  1.9649e-02,\n",
       "         -2.8926e-01, -5.9976e-02, -1.3830e-01,  1.5823e-01,  1.2106e-01,\n",
       "         -2.6639e-03, -2.6532e-01,  1.6724e-01, -1.5998e-01,  4.8005e-01,\n",
       "         -1.6915e-01, -3.1967e-01, -6.6139e-02,  2.3338e-01, -3.9123e-02,\n",
       "          6.2504e-02, -1.5678e-01,  3.7560e-02, -3.7048e-01,  4.5680e-02,\n",
       "         -1.7966e-01, -1.6408e-01,  3.6574e-02, -1.8910e-01,  1.2427e-02,\n",
       "          9.3017e-02, -2.2461e-01,  1.5170e-01,  4.3417e-02, -4.6470e-02,\n",
       "         -1.5336e-01, -1.8671e-01,  1.2724e-01, -6.9469e-03,  1.3931e-01,\n",
       "         -2.9443e-01, -3.1339e-01, -1.6544e-01,  1.9363e-01, -1.6315e-01,\n",
       "         -1.1992e-01,  5.4212e-03, -7.6263e-02,  4.2482e-01, -2.8010e-01,\n",
       "         -1.0298e-01,  1.0719e-01, -2.1474e-01, -2.9642e-01,  1.9715e-01,\n",
       "         -3.2183e-02,  4.1459e-02, -3.9711e-02, -1.1227e-01, -3.1510e-02,\n",
       "         -5.2222e-03, -1.2478e-01, -2.6623e-01, -8.5785e-02, -2.6140e-01,\n",
       "         -2.8393e-01, -5.3811e-02, -1.6795e-01,  7.7608e-02, -9.4560e-02,\n",
       "          1.8537e-01, -1.3381e-01,  1.1088e-01, -1.0408e-01, -1.3889e-01,\n",
       "         -5.9964e-03, -1.5640e-01, -1.8662e-01, -2.1673e-02,  4.8620e-02,\n",
       "         -1.8421e-01,  2.7766e-01, -1.6006e-01, -8.4446e-02, -8.3540e-02,\n",
       "         -3.0732e-01,  1.9243e-01, -8.9779e-02,  4.1132e-02, -3.4977e-02,\n",
       "         -3.3654e-01, -1.1833e-01, -5.6749e-02, -1.0998e-01, -7.1620e-02,\n",
       "         -3.3300e-01,  6.9617e-02, -6.5482e-03,  3.0713e-01,  1.4317e-01,\n",
       "          2.5623e-01,  2.1503e-01, -1.2110e-01,  2.0372e-01, -3.1207e-01,\n",
       "          1.6217e-01, -2.1337e-01, -4.7750e-02, -1.7767e-01,  8.9443e-02,\n",
       "          4.3317e-03, -3.0448e-01, -3.8125e-02,  1.5997e-01,  1.5064e-01,\n",
       "         -2.7687e-01,  5.7003e-02, -1.8123e-02,  3.5218e-02,  4.0356e-01,\n",
       "         -3.3884e-01,  1.1523e-01,  1.4304e-01,  2.0349e-01,  1.4582e-01,\n",
       "         -2.2050e-01,  8.6250e-02, -6.1388e-02, -1.0967e-02,  2.5518e-02,\n",
       "         -2.6198e-02,  1.5040e-01, -1.2555e-01, -2.5882e-01, -1.5671e-01,\n",
       "         -3.1024e-01, -3.4687e-01, -2.1213e-01, -1.0582e-01, -5.4677e-02,\n",
       "          1.2054e-01,  1.5666e-01, -2.8328e-01, -3.4375e-02, -2.3331e-01,\n",
       "         -1.0096e-01, -1.5373e-02, -2.1569e-01,  1.5409e-01,  1.8641e-02,\n",
       "          1.7794e-01,  3.8752e-02, -7.4389e-02, -1.0127e-01, -8.4036e-02,\n",
       "         -1.8661e-01,  4.5518e-02, -2.3550e-01,  5.5479e-02, -3.1183e-01,\n",
       "         -2.0483e-01,  8.4746e-02, -2.9469e-01,  2.3071e-01, -8.6013e-02,\n",
       "         -4.2235e-02,  1.5816e-01, -3.2402e-01, -1.8819e-01, -1.8873e-01,\n",
       "         -4.4247e-02, -2.7022e-01,  2.3252e-01, -1.0359e-01, -2.2290e-01,\n",
       "          9.7211e-02,  7.8672e-02, -1.4051e-01, -3.1084e-02, -1.8090e-01,\n",
       "         -1.7161e-01,  2.3969e-02,  9.9099e-03, -7.8489e-02, -3.6927e-01,\n",
       "          5.8634e-02,  3.3587e-01, -1.6323e-02,  4.2122e-01,  7.0401e-02,\n",
       "         -2.5359e-01, -9.4064e-02,  1.5538e-01,  1.6284e-03, -1.9452e-01,\n",
       "          1.4135e-01,  6.7474e-02,  4.2442e-02, -6.8525e-02,  7.9882e-02,\n",
       "          1.1528e-01, -2.4441e-01, -1.3187e-01, -3.4780e-03, -4.2392e-02,\n",
       "          2.6853e-02,  1.4400e-01, -9.9624e-02, -2.6294e-01, -1.1432e-01,\n",
       "          2.3365e-01, -7.1575e-02,  4.7919e-02,  1.3444e-02, -7.8988e-02,\n",
       "         -2.6020e-01,  1.8627e-01, -1.6766e-01, -2.4382e-01, -2.3937e-03,\n",
       "         -1.1106e-01, -1.1733e-01, -2.1904e-03, -6.1306e-02, -6.8139e-02,\n",
       "         -6.0180e-02, -2.7575e-02, -1.6927e-01,  2.2123e-02,  5.1046e-03,\n",
       "          9.1977e-02, -4.2010e-01,  3.1531e-02, -1.0432e-01, -5.9499e-02,\n",
       "         -1.4227e-01, -2.2071e-03, -1.0201e-01, -2.0395e-01, -6.2591e-02,\n",
       "         -9.9255e-02, -5.9713e-02,  6.9466e-03,  3.9616e-02,  2.5549e-01,\n",
       "         -1.3089e-01,  1.3629e-01,  1.9560e-02, -1.3970e-01,  7.6565e-02,\n",
       "         -2.5834e-01, -1.4425e-01, -4.1600e-02,  6.4720e-02, -4.8293e-02,\n",
       "          3.9835e-02, -1.0735e-01, -1.9065e-01, -1.0297e-02, -4.6257e-01,\n",
       "         -1.1275e-01,  1.3461e-01,  2.3439e-02,  1.3629e-01, -1.1407e-01,\n",
       "         -5.8173e-02, -1.8470e-01, -1.0108e-01, -2.1546e-01,  8.1609e-02,\n",
       "         -1.1135e-01, -1.9366e-01, -2.5534e-02,  7.1518e-03,  8.3766e-02,\n",
       "          2.0059e-01, -2.4005e-01,  2.7818e-01,  1.1070e-01, -2.9542e-02,\n",
       "         -2.2901e-01,  2.2883e-01, -3.4692e-02,  5.7272e-02, -9.1160e-02,\n",
       "         -2.5467e-01,  6.2668e-03, -1.0490e-01, -9.0321e-02,  4.8394e-02,\n",
       "          8.3008e-02, -2.6141e-02,  6.4451e-02, -1.5710e-01, -8.5294e-02,\n",
       "         -7.5820e-02,  3.4173e-02, -3.7875e-02, -1.0733e-01,  9.8452e-02,\n",
       "          1.7418e-01,  3.8679e-01,  1.6042e-01, -1.2213e-01,  3.1008e-01,\n",
       "         -2.8292e-01, -2.6848e-01, -2.1315e-01,  8.9100e-02,  5.7202e-02,\n",
       "         -2.5249e-01, -3.5481e-02, -1.0943e-01, -1.4575e-01, -2.3862e-01,\n",
       "         -4.1532e-01, -1.7416e-01, -5.9082e-02,  6.0813e-02, -1.8539e-01,\n",
       "         -2.2777e-01, -1.4130e-01,  1.1217e-01, -9.2016e-02, -1.5192e-01,\n",
       "         -1.1104e-01,  4.2307e-01, -1.9202e-01, -2.6294e-01, -6.2639e-02,\n",
       "          7.3836e-03, -2.8055e-02, -3.0666e-02,  6.3371e-02,  3.2942e-03,\n",
       "         -3.9766e-02, -2.2645e-01, -1.0134e-01,  1.4264e-01, -3.6185e-01,\n",
       "         -1.7681e-01,  2.0689e-02, -6.9407e-02, -9.9117e-02,  1.5134e-01,\n",
       "         -1.6631e-01,  1.7213e-01,  1.5422e-01,  2.4922e-02,  1.6776e-01,\n",
       "         -2.8483e-02,  2.0558e-01, -6.2833e-02,  1.3241e-01,  1.6673e-01,\n",
       "          3.5344e-01, -1.4939e-01, -1.3477e-02,  4.9669e-02, -2.1777e-01,\n",
       "          1.1248e-01, -1.0244e-01,  4.4273e-02, -5.9534e-03, -1.0321e-01,\n",
       "          1.0159e-01, -1.7996e-01, -1.0837e-01, -1.0760e-01, -7.9204e-02,\n",
       "         -7.6940e-02,  1.8530e-01, -3.8072e-02, -3.2123e-01, -3.1962e-01,\n",
       "         -3.5619e-01, -2.7386e-01, -2.4828e-01, -2.1376e-01, -6.6686e-02,\n",
       "         -1.7495e-02, -2.3398e-01, -1.5790e-01, -3.8745e-01, -3.0398e-01,\n",
       "         -2.5111e-01,  2.1225e-02, -2.1580e-01, -1.6745e-01, -1.3299e-01,\n",
       "         -1.1569e-01, -1.4653e-01, -1.0050e-01, -1.0652e-01, -1.9919e-02,\n",
       "          4.3570e-03, -5.5960e-02,  6.1171e-02, -9.7138e-02,  1.0088e-02,\n",
       "         -1.7262e-01, -1.0604e-01, -1.0942e-01, -1.0560e-01, -8.3540e-02,\n",
       "         -1.3718e-01,  1.7382e-01, -1.2698e-01,  1.4962e-01, -3.2334e-01,\n",
       "         -4.1813e-02,  1.0611e-01, -2.8023e-01, -3.7135e-01, -2.9572e-01,\n",
       "         -3.0216e-01,  1.4628e-01, -1.2960e-01, -1.5928e-01,  7.4739e-02,\n",
       "          2.4275e-01,  5.4764e-02,  3.4404e-01, -7.5884e-02,  2.0467e-01,\n",
       "          1.6417e-01,  1.9391e-01,  2.8781e-01,  1.9518e-01,  3.3620e-01,\n",
       "          3.1126e-01, -2.1227e-01, -3.3151e-02, -1.8200e-01,  6.6010e-02,\n",
       "         -1.3311e-01, -1.4060e-01, -1.8008e-01,  1.1491e-01, -2.2763e-02,\n",
       "          1.4658e-02, -1.2218e-01, -1.3545e-01,  1.0391e-01, -5.5223e-02,\n",
       "          9.8317e-02,  2.7491e-02, -8.7586e-02, -1.7342e-01, -5.4462e-02]],\n",
       "       grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res_model(rand_inp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1000)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Count the output elements on which res_model and model agree exactly for rand_inp.\n",
    "(res_model(rand_inp) == model(rand_inp)).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebind `model` to a ResNet50Classifier built with num_labels=100.\n",
    "model = ResNet50Classifier(num_labels=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1395,  0.0264,  0.1792, -0.1296, -0.0041, -0.0005, -0.0438, -0.0258,\n",
       "          0.1141,  0.1187,  0.0411, -0.0523,  0.0065,  0.1429,  0.0344, -0.0276,\n",
       "          0.0235, -0.0239, -0.0538, -0.0391,  0.0514, -0.0932,  0.1188, -0.0155,\n",
       "          0.0672, -0.0913,  0.0072, -0.0363,  0.0946, -0.0158, -0.0418,  0.0177,\n",
       "          0.0341,  0.0065, -0.1130,  0.0176,  0.0417,  0.0310,  0.0085,  0.1447,\n",
       "         -0.0054,  0.0998,  0.0856, -0.0721, -0.0407,  0.1116, -0.0731,  0.0410,\n",
       "         -0.0611,  0.1096,  0.0182, -0.1676,  0.0489,  0.0184, -0.0820,  0.1279,\n",
       "          0.0177,  0.0916, -0.1222,  0.0749, -0.0103,  0.0689,  0.0447, -0.0675,\n",
       "         -0.1039, -0.0654, -0.0503, -0.0081, -0.0439, -0.0475,  0.0012,  0.0877,\n",
       "         -0.1784,  0.0677, -0.1436, -0.0046, -0.1328, -0.0223, -0.0186,  0.0434,\n",
       "          0.1582, -0.0115, -0.0265,  0.0889, -0.0101, -0.0440,  0.0358,  0.1964,\n",
       "          0.0060,  0.0120,  0.0507,  0.0615,  0.0233,  0.1072,  0.0557,  0.1296,\n",
       "          0.0154,  0.0466,  0.0560, -0.1198]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run rand_inp through the current `model` and display the raw output tensor.\n",
    "model(rand_inp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SwinClassifier(nn.Module):\n",
    "    \"\"\"Image classifier built on torchvision's pretrained Swin-V2-B.\n",
    "\n",
    "    The torchvision model is split in two: every child module except the\n",
    "    last one (``head``) becomes ``repr_model``, and ``classifier`` is the\n",
    "    linear layer applied to its output.\n",
    "\n",
    "    Args:\n",
    "        num_labels: Number of output classes. If it matches the pretrained\n",
    "            head's ``out_features`` that head is reused; otherwise a new,\n",
    "            randomly initialised ``nn.Linear`` is created.\n",
    "        *args, **kwargs: Passed through to ``nn.Module.__init__``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_labels, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        # `weights` must be an enum member. Passing the bare\n",
    "        # `Swin_V2_B_Weights` class is not a valid value and seems to load\n",
    "        # weights only through torchvision's deprecated `pretrained` fallback.\n",
    "        backbone = swin_v2_b(weights=Swin_V2_B_Weights.IMAGENET1K_V1)\n",
    "        pretrained_head = backbone.head\n",
    "        self.repr_model = nn.Sequential(*list(backbone.children())[:-1])\n",
    "        if num_labels == pretrained_head.out_features:\n",
    "            self.classifier = pretrained_head\n",
    "        else:\n",
    "            # Take the input width from the pretrained head instead of hard-coding it.\n",
    "            self.classifier = nn.Linear(\n",
    "                in_features=pretrained_head.in_features,\n",
    "                out_features=num_labels,\n",
    "                bias=True,\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the classifier output for the input batch ``x``.\"\"\"\n",
    "        features = self.repr_model(x)\n",
    "        return self.classifier(features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# num_labels=1000 is the case in which SwinClassifier reuses the pretrained head.\n",
    "swin_trmodel = SwinClassifier(num_labels=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SwinClassifier(\n",
       "  (repr_model): Sequential(\n",
       "    (0): Sequential(\n",
       "      (0): Sequential(\n",
       "        (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))\n",
       "        (1): Permute()\n",
       "        (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=128, out_features=384, bias=True)\n",
       "            (proj): Linear(in_features=128, out_features=128, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=4, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.0, mode=row)\n",
       "          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=128, out_features=512, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=512, out_features=128, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (1): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=128, out_features=384, bias=True)\n",
       "            (proj): Linear(in_features=128, out_features=128, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=4, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.021739130434782608, mode=row)\n",
       "          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=128, out_features=512, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=512, out_features=128, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (2): PatchMergingV2(\n",
       "        (reduction): Linear(in_features=512, out_features=256, bias=False)\n",
       "        (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "      (3): Sequential(\n",
       "        (0): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=256, out_features=768, bias=True)\n",
       "            (proj): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=8, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.043478260869565216, mode=row)\n",
       "          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=256, out_features=1024, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=1024, out_features=256, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (1): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=256, out_features=768, bias=True)\n",
       "            (proj): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=8, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.06521739130434782, mode=row)\n",
       "          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=256, out_features=1024, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=1024, out_features=256, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (4): PatchMergingV2(\n",
       "        (reduction): Linear(in_features=1024, out_features=512, bias=False)\n",
       "        (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "      (5): Sequential(\n",
       "        (0): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.08695652173913043, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (1): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.10869565217391304, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (2): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.13043478260869565, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (3): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.15217391304347827, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (4): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.17391304347826086, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (5): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.1956521739130435, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (6): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.21739130434782608, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (7): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.2391304347826087, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (8): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.2608695652173913, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (9): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.2826086956521739, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (10): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.30434782608695654, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (11): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.32608695652173914, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (12): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.34782608695652173, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (13): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.3695652173913043, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (14): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.391304347826087, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (15): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.41304347826086957, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (16): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.43478260869565216, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (17): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=512, out_features=1536, bias=True)\n",
       "            (proj): Linear(in_features=512, out_features=512, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=16, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.45652173913043476, mode=row)\n",
       "          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=512, out_features=2048, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=2048, out_features=512, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (6): PatchMergingV2(\n",
       "        (reduction): Linear(in_features=2048, out_features=1024, bias=False)\n",
       "        (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "      (7): Sequential(\n",
       "        (0): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=32, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.4782608695652174, mode=row)\n",
       "          (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (1): SwinTransformerBlockV2(\n",
       "          (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ShiftedWindowAttentionV2(\n",
       "            (qkv): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (cpb_mlp): Sequential(\n",
       "              (0): Linear(in_features=2, out_features=512, bias=True)\n",
       "              (1): ReLU(inplace=True)\n",
       "              (2): Linear(in_features=512, out_features=32, bias=False)\n",
       "            )\n",
       "          )\n",
       "          (stochastic_depth): StochasticDepth(p=0.5, mode=row)\n",
       "          (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLP(\n",
       "            (0): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "            (2): Dropout(p=0.0, inplace=False)\n",
       "            (3): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "            (4): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "    (2): Permute()\n",
       "    (3): AdaptiveAvgPool2d(output_size=1)\n",
       "    (4): Flatten(start_dim=1, end_dim=-1)\n",
       "  )\n",
       "  (classifier): Linear(in_features=1024, out_features=1000, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "swin_trmodel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([21])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Switch to eval mode so the StochasticDepth / Dropout layers shown in the repr\n",
    "# above are inactive and the prediction is deterministic.\n",
    "swin_trmodel.eval()\n",
    "# Inference only: no_grad avoids building an autograd graph for the forward pass.\n",
    "with torch.no_grad():\n",
    "    pred_trmodel = torch.argmax(swin_trmodel(rand_inp), dim=1)\n",
    "# Index of the largest logit along dim 1 for each row of rand_inp.\n",
    "pred_trmodel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([21])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Switch to eval mode so train-only layers (e.g. dropout / stochastic depth,\n",
    "# if this model has them) are inactive and the prediction is deterministic.\n",
    "swin_model.eval()\n",
    "# Inference only: no_grad avoids building an autograd graph for the forward pass.\n",
    "with torch.no_grad():\n",
    "    pred_model = torch.argmax(swin_model(rand_inp), dim=1)\n",
    "# Index of the largest logit along dim 1 for each row of rand_inp.\n",
    "pred_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1000)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Count how many output logits of the two Swin models are exactly equal for the\n",
    "# same input; a count equal to the total number of logits means identical outputs.\n",
    "# Only values are compared, so run both forward passes without autograd tracking.\n",
    "with torch.no_grad():\n",
    "    logits_trmodel = swin_trmodel(rand_inp)\n",
    "    logits_model = swin_model(rand_inp)\n",
    "# NOTE(review): this is exact float equality; use torch.allclose instead if a\n",
    "# numerical tolerance is intended.\n",
    "torch.sum(logits_trmodel == logits_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "neel",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
