{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "56027773",
   "metadata": {},
   "source": [
    "# Model Exploration\n",
    "\n",
    "This notebook explores model architectures in order to create AD variants."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "18c1e255",
   "metadata": {},
   "outputs": [],
   "source": [
    "import timm\n",
    "import torch\n",
    "from torchinfo import summary\n",
    "from interpretable_regnety import InterpretableRegNetY\n",
    "from interpretable_efficientnetv2 import InterpretableEfficientNetV2\n",
    "from interpretable_mobilenetv4 import InterpretableMobileNetV4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b9a58de3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ImageClassification(\n",
       "    crop_size=[224]\n",
       "    resize_size=[232]\n",
       "    mean=[0.485, 0.456, 0.406]\n",
       "    std=[0.229, 0.224, 0.225]\n",
       "    interpolation=InterpolationMode.BILINEAR\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torchvision.models import get_model, ResNet50_Weights\n",
    "ResNet50_Weights.DEFAULT.transforms()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cd998c37",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MobileNetV3(\n",
       "  (conv_stem): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "  (bn1): BatchNormAct2d(\n",
       "    24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "    (drop): Identity()\n",
       "    (act): ReLU(inplace=True)\n",
       "  )\n",
       "  (blocks): Sequential(\n",
       "    (0): Sequential(\n",
       "      (0): EdgeResidual(\n",
       "        (conv_exp): Conv2d(24, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "        (aa): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (se): Identity()\n",
       "        (conv_pwl): Conv2d(96, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "    )\n",
       "    (1): Sequential(\n",
       "      (0): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(48, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(192, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=192, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "          (aa): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (1): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "    )\n",
       "    (2): Sequential(\n",
       "      (0): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(384, 384, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=384, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "          (aa): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(384, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (1): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (2): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (3): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (4): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(768, 768, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=768, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (5): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(192, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=192, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (6): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(192, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=192, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (7): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(192, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=192, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (8): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(192, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=192, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (9): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(192, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=192, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (10): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): Identity()\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "    )\n",
       "    (3): Sequential(\n",
       "      (0): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(192, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=192, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(768, 768, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=768, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "          (aa): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(768, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (1): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 2048, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=2048, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (2): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 2048, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=2048, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (3): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 2048, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=2048, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (4): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): Identity()\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (5): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2048, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (6): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): Identity()\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (7): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): Identity()\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (8): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2048, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (9): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 2048, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=2048, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (10): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): Identity()\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (11): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): Identity()\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (12): UniversalInvertedResidual(\n",
       "        (dw_start): ConvNormAct(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (pw_exp): ConvNormAct(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "        )\n",
       "        (dw_mid): Identity()\n",
       "        (se): Identity()\n",
       "        (pw_proj): ConvNormAct(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNormAct2d(\n",
       "            512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "        )\n",
       "        (dw_end): Identity()\n",
       "        (layer_scale): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "    )\n",
       "    (4): Sequential(\n",
       "      (0): ConvBnAct(\n",
       "        (conv): Conv2d(512, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Identity())\n",
       "  (conv_head): Conv2d(960, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "  (norm_head): BatchNormAct2d(\n",
       "    1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "    (drop): Identity()\n",
       "    (act): ReLU(inplace=True)\n",
       "  )\n",
       "  (act2): Identity()\n",
       "  (flatten): Flatten(start_dim=1, end_dim=-1)\n",
       "  (classifier): Linear(in_features=1280, out_features=1000, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = timm.create_model('mobilenetv4_conv_aa_large.e230_r448_in12k_ft_in1k', pretrained=True)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e2bfda63",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n",
      "Identity()\n"
     ]
    }
   ],
   "source": [
    "for block in model.blocks:\n",
    "    for stage in block:\n",
    "        if hasattr(stage, 'dw_end'):\n",
    "            print(stage.dw_end)\n",
    "        if hasattr(stage, 'layer_scale'):\n",
    "            print(stage.layer_scale)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "72c28b11",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Compose(\n",
       "    Resize(size=471, interpolation=bicubic, max_size=None, antialias=True)\n",
       "    CenterCrop(size=(448, 448))\n",
       "    MaybeToTensor()\n",
       "    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "13be726c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "====================================================================================================\n",
       "Layer (type:depth-idx)                             Output Shape              Param #\n",
       "====================================================================================================\n",
       "MobileNetV3                                        [1, 1000]                 --\n",
       "├─Conv2d: 1-1                                      [1, 24, 224, 224]         648\n",
       "├─BatchNormAct2d: 1-2                              [1, 24, 224, 224]         48\n",
       "│    └─Identity: 2-1                               [1, 24, 224, 224]         --\n",
       "│    └─ReLU: 2-2                                   [1, 24, 224, 224]         --\n",
       "├─Sequential: 1-3                                  [1, 960, 14, 14]          --\n",
       "│    └─Sequential: 2-3                             [1, 48, 112, 112]         --\n",
       "│    │    └─EdgeResidual: 3-1                      [1, 48, 112, 112]         25,632\n",
       "│    └─Sequential: 2-4                             [1, 96, 56, 56]           --\n",
       "│    │    └─UniversalInvertedResidual: 3-2         [1, 96, 56, 56]           33,936\n",
       "│    │    └─UniversalInvertedResidual: 3-3         [1, 96, 56, 56]           79,968\n",
       "│    └─Sequential: 2-5                             [1, 192, 28, 28]          --\n",
       "│    │    └─UniversalInvertedResidual: 3-4         [1, 192, 28, 28]          123,168\n",
       "│    │    └─UniversalInvertedResidual: 3-5         [1, 192, 28, 28]          307,392\n",
       "│    │    └─UniversalInvertedResidual: 3-6         [1, 192, 28, 28]          307,392\n",
       "│    │    └─UniversalInvertedResidual: 3-7         [1, 192, 28, 28]          307,392\n",
       "│    │    └─UniversalInvertedResidual: 3-8         [1, 192, 28, 28]          319,680\n",
       "│    │    └─UniversalInvertedResidual: 3-9         [1, 192, 28, 28]          310,464\n",
       "│    │    └─UniversalInvertedResidual: 3-10        [1, 192, 28, 28]          310,464\n",
       "│    │    └─UniversalInvertedResidual: 3-11        [1, 192, 28, 28]          310,464\n",
       "│    │    └─UniversalInvertedResidual: 3-12        [1, 192, 28, 28]          310,464\n",
       "│    │    └─UniversalInvertedResidual: 3-13        [1, 192, 28, 28]          310,464\n",
       "│    │    └─UniversalInvertedResidual: 3-14        [1, 192, 28, 28]          298,944\n",
       "│    └─Sequential: 2-6                             [1, 512, 14, 14]          --\n",
       "│    │    └─UniversalInvertedResidual: 3-15        [1, 512, 14, 14]          569,152\n",
       "│    │    └─UniversalInvertedResidual: 3-16        [1, 512, 14, 14]          2,171,392\n",
       "│    │    └─UniversalInvertedResidual: 3-17        [1, 512, 14, 14]          2,171,392\n",
       "│    │    └─UniversalInvertedResidual: 3-18        [1, 512, 14, 14]          2,171,392\n",
       "│    │    └─UniversalInvertedResidual: 3-19        [1, 512, 14, 14]          2,116,096\n",
       "│    │    └─UniversalInvertedResidual: 3-20        [1, 512, 14, 14]          2,138,624\n",
       "│    │    └─UniversalInvertedResidual: 3-21        [1, 512, 14, 14]          2,116,096\n",
       "│    │    └─UniversalInvertedResidual: 3-22        [1, 512, 14, 14]          2,116,096\n",
       "│    │    └─UniversalInvertedResidual: 3-23        [1, 512, 14, 14]          2,138,624\n",
       "│    │    └─UniversalInvertedResidual: 3-24        [1, 512, 14, 14]          2,171,392\n",
       "│    │    └─UniversalInvertedResidual: 3-25        [1, 512, 14, 14]          2,116,096\n",
       "│    │    └─UniversalInvertedResidual: 3-26        [1, 512, 14, 14]          2,116,096\n",
       "│    │    └─UniversalInvertedResidual: 3-27        [1, 512, 14, 14]          2,116,096\n",
       "│    └─Sequential: 2-7                             [1, 960, 14, 14]          --\n",
       "│    │    └─ConvBnAct: 3-28                        [1, 960, 14, 14]          493,440\n",
       "├─SelectAdaptivePool2d: 1-4                        [1, 960, 1, 1]            --\n",
       "│    └─AdaptiveAvgPool2d: 2-8                      [1, 960, 1, 1]            --\n",
       "│    └─Identity: 2-9                               [1, 960, 1, 1]            --\n",
       "├─Conv2d: 1-5                                      [1, 1280, 1, 1]           1,228,800\n",
       "├─BatchNormAct2d: 1-6                              [1, 1280, 1, 1]           2,560\n",
       "│    └─Identity: 2-10                              [1, 1280, 1, 1]           --\n",
       "│    └─ReLU: 2-11                                  [1, 1280, 1, 1]           --\n",
       "├─Identity: 1-7                                    [1, 1280, 1, 1]           --\n",
       "├─Flatten: 1-8                                     [1, 1280]                 --\n",
       "├─Linear: 1-9                                      [1, 1000]                 1,281,000\n",
       "====================================================================================================\n",
       "Total params: 32,590,864\n",
       "Trainable params: 32,590,864\n",
       "Non-trainable params: 0\n",
       "Total mult-adds (Units.GIGABYTES): 9.54\n",
       "====================================================================================================\n",
       "Input size (MB): 2.41\n",
       "Forward/backward pass size (MB): 351.55\n",
       "Params size (MB): 129.77\n",
       "Estimated Total Size (MB): 483.73\n",
       "===================================================================================================="
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_config = timm.data.resolve_model_data_config(model)\n",
    "transforms = timm.data.create_transform(**data_config, is_training=False)\n",
    "# transforms\n",
    "summary(model, input_size = (1, 3, 448, 448))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3a0e7e24",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "InterpretableMobileNetV4(\n",
       "  (model): MobileNetV3(\n",
       "    (conv_stem): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (bn1): BatchNormAct2d(\n",
       "      24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "      (drop): Identity()\n",
       "      (act): ReLU(inplace=True)\n",
       "    )\n",
       "    (blocks): Sequential(\n",
       "      (0): Sequential(\n",
       "        (0): EdgeResidual(\n",
       "          (conv_exp): Conv2d(24, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "          (aa): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "          (se): Identity()\n",
       "          (conv_pwl): Conv2d(96, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(48, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(192, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=192, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "            (aa): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (1): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "      )\n",
       "      (2): Sequential(\n",
       "        (0): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(384, 384, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=384, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "            (aa): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(384, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (1): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (2): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (3): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (4): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(768, 768, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=768, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (5): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(192, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=192, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (6): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(192, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=192, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (7): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(192, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=192, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (8): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(192, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=192, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (9): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(192, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=192, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (10): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): Identity()\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "      )\n",
       "      (3): Sequential(\n",
       "        (0): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(192, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=192, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(768, 768, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=768, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "            (aa): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(768, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (1): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 2048, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=2048, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (2): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 2048, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=2048, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (3): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 2048, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=2048, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (4): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): Identity()\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (5): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2048, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (6): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): Identity()\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (7): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): Identity()\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (8): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2048, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (9): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 2048, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=2048, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (10): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): Identity()\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (11): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): Identity()\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (12): UniversalInvertedResidual(\n",
       "          (dw_start): ConvNormAct(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512, bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (pw_exp): ConvNormAct(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): ReLU(inplace=True)\n",
       "            )\n",
       "          )\n",
       "          (dw_mid): Identity()\n",
       "          (se): Identity()\n",
       "          (pw_proj): ConvNormAct(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNormAct2d(\n",
       "              512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "              (drop): Identity()\n",
       "              (act): Identity()\n",
       "            )\n",
       "          )\n",
       "          (dw_end): Identity()\n",
       "          (layer_scale): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "      )\n",
       "      (4): Sequential(\n",
       "        (0): ConvBnAct(\n",
       "          (conv): Conv2d(512, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): ReLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Identity())\n",
       "    (conv_head): Conv2d(960, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "    (norm_head): BatchNormAct2d(\n",
       "      1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "      (drop): Identity()\n",
       "      (act): ReLU(inplace=True)\n",
       "    )\n",
       "    (act2): Identity()\n",
       "    (flatten): Flatten(start_dim=1, end_dim=-1)\n",
       "    (classifier): Linear(in_features=1280, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4b2be1a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Private fixed-seed generator: the dummy inputs are identical on every\n",
    "# Restart & Run All, and torch's global RNG state is left untouched.\n",
    "_input_rng = torch.Generator().manual_seed(0)\n",
    "\n",
    "# Random single-image, 3-channel dummy inputs at three square resolutions.\n",
    "# 448 matches the CenterCrop size reported by InterpretableMobileNetV4().transforms;\n",
    "# NOTE(review): 384 and 300 presumably match the other wrappers' input sizes - confirm.\n",
    "a = torch.randn(1, 3, 384, 384, generator=_input_rng)\n",
    "b = torch.randn(1, 3, 300, 300, generator=_input_rng)\n",
    "c = torch.randn(1, 3, 448, 448, generator=_input_rng)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "95fd2921",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Compose(\n",
       "    Resize(size=471, interpolation=bicubic, max_size=None, antialias=True)\n",
       "    CenterCrop(size=(448, 448))\n",
       "    MaybeToTensor()\n",
       "    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the interpretable MobileNetV4 wrapper with its default constructor arguments.\n",
    "model = InterpretableMobileNetV4()\n",
    "\n",
    "# Disabled sanity check, kept for reference: top-5 softmax scores on the dummy input `c`,\n",
    "# once in explanation mode with an all-ones mask and once with explanation mode off.\n",
    "# print(\"Explanation Model:\", torch.topk(model(c, explanation_mode = True, explanation_mask = torch.ones_like(c)).softmax(dim = 1), k=5))\n",
    "# print(\"Original Model:\", torch.topk(model(c, explanation_mode = False).softmax(dim = 1), k=5))\n",
    "\n",
    "# Show the preprocessing pipeline exposed by the wrapper.\n",
    "model.transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e2344ba5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "InterpretableEfficientNetV2(\n",
       "  (model): EfficientNet(\n",
       "    (conv_stem): Conv2dSame(3, 24, kernel_size=(3, 3), stride=(2, 2), bias=False)\n",
       "    (bn1): BatchNormAct2d(\n",
       "      24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "      (drop): Identity()\n",
       "      (act): SiLU(inplace=True)\n",
       "    )\n",
       "    (blocks): Sequential(\n",
       "      (0): Sequential(\n",
       "        (0): ConvBnAct(\n",
       "          (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (1): ConvBnAct(\n",
       "          (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): EdgeResidual(\n",
       "          (conv_exp): Conv2dSame(24, 96, kernel_size=(3, 3), stride=(2, 2), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): Identity()\n",
       "          (conv_pwl): Conv2d(96, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (1): EdgeResidual(\n",
       "          (conv_exp): Conv2d(48, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): Identity()\n",
       "          (conv_pwl): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (2): EdgeResidual(\n",
       "          (conv_exp): Conv2d(48, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): Identity()\n",
       "          (conv_pwl): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (3): EdgeResidual(\n",
       "          (conv_exp): Conv2d(48, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): Identity()\n",
       "          (conv_pwl): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "      )\n",
       "      (2): Sequential(\n",
       "        (0): EdgeResidual(\n",
       "          (conv_exp): Conv2dSame(48, 192, kernel_size=(3, 3), stride=(2, 2), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): Identity()\n",
       "          (conv_pwl): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (1): EdgeResidual(\n",
       "          (conv_exp): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): Identity()\n",
       "          (conv_pwl): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (2): EdgeResidual(\n",
       "          (conv_exp): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): Identity()\n",
       "          (conv_pwl): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (3): EdgeResidual(\n",
       "          (conv_exp): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): Identity()\n",
       "          (conv_pwl): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "      )\n",
       "      (3): Sequential(\n",
       "        (0): InvertedResidual(\n",
       "          (conv_pw): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2dSame(256, 256, kernel_size=(3, 3), stride=(2, 2), groups=256, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(256, 16, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (1): InvertedResidual(\n",
       "          (conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (2): InvertedResidual(\n",
       "          (conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (3): InvertedResidual(\n",
       "          (conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (4): InvertedResidual(\n",
       "          (conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (5): InvertedResidual(\n",
       "          (conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "      )\n",
       "      (4): Sequential(\n",
       "        (0): InvertedResidual(\n",
       "          (conv_pw): Conv2d(128, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            768, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            768, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(768, 32, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(32, 768, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(768, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (1): InvertedResidual(\n",
       "          (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (2): InvertedResidual(\n",
       "          (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (3): InvertedResidual(\n",
       "          (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (4): InvertedResidual(\n",
       "          (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (5): InvertedResidual(\n",
       "          (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (6): InvertedResidual(\n",
       "          (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (7): InvertedResidual(\n",
       "          (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (8): InvertedResidual(\n",
       "          (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "      )\n",
       "      (5): Sequential(\n",
       "        (0): InvertedResidual(\n",
       "          (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2dSame(960, 960, kernel_size=(3, 3), stride=(2, 2), groups=960, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(960, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (1): InvertedResidual(\n",
       "          (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (2): InvertedResidual(\n",
       "          (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (3): InvertedResidual(\n",
       "          (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (4): InvertedResidual(\n",
       "          (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (5): InvertedResidual(\n",
       "          (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (6): InvertedResidual(\n",
       "          (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (7): InvertedResidual(\n",
       "          (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (8): InvertedResidual(\n",
       "          (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (9): InvertedResidual(\n",
       "          (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (10): InvertedResidual(\n",
       "          (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (11): InvertedResidual(\n",
       "          (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (12): InvertedResidual(\n",
       "          (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (13): InvertedResidual(\n",
       "          (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (14): InvertedResidual(\n",
       "          (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn1): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "          (bn2): BatchNormAct2d(\n",
       "            1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): SiLU(inplace=True)\n",
       "          )\n",
       "          (aa): Identity()\n",
       "          (se): SqueezeExcite(\n",
       "            (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (act1): SiLU(inplace=True)\n",
       "            (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "            (gate): Sigmoid()\n",
       "          )\n",
       "          (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn3): BatchNormAct2d(\n",
       "            256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "            (drop): Identity()\n",
       "            (act): Identity()\n",
       "          )\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (conv_head): Conv2d(256, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "    (bn2): BatchNormAct2d(\n",
       "      1280, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "      (drop): Identity()\n",
       "      (act): SiLU(inplace=True)\n",
       "    )\n",
       "    (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
       "    (classifier): Linear(in_features=1280, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = InterpretableEfficientNetV2()\n",
    "# print(\"Explanation Model:\", torch.topk(model(b, explanation_mode = True, explanation_mask = torch.ones_like(b)).softmax(dim = 1), k=5))\n",
    "# print(\"Original Model:\", torch.topk(model(b, explanation_mode = False).softmax(dim = 1), k=5))\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bf9aa120",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Compose(\n",
       "    Resize(size=235, interpolation=bicubic, max_size=None, antialias=True)\n",
       "    CenterCrop(size=(224, 224))\n",
       "    MaybeToTensor()\n",
       "    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = InterpretableRegNetY()\n",
    "# print(\"Explanation Model:\", torch.topk(model(a, explanation_mode = True, explanation_mask = torch.ones_like(a)).softmax(dim = 1), k=5))\n",
    "# print(\"Original Model:\", torch.topk(model(a, explanation_mode = False).softmax(dim = 1), k=5))\n",
    "model.transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f9cf277d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Compose(\n",
       "    Resize(size=235, interpolation=bicubic, max_size=None, antialias=True)\n",
       "    CenterCrop(size=(224, 224))\n",
       "    MaybeToTensor()\n",
       "    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = timm.create_model('regnety_120.sw_in12k_ft_in1k', pretrained=True)\n",
    "data_config = timm.data.resolve_model_data_config(model)\n",
    "transforms = timm.data.create_transform(**data_config, is_training=False)\n",
    "transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3324418",
   "metadata": {},
   "source": [
    "# ConvNext-V2 Tiny "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2857079a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ConvNeXt(\n",
       "  (stem): Sequential(\n",
       "    (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))\n",
       "    (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)\n",
       "  )\n",
       "  (stages): Sequential(\n",
       "    (0): ConvNeXtStage(\n",
       "      (downsample): Identity()\n",
       "      (blocks): Sequential(\n",
       "        (0): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)\n",
       "          (norm): LayerNorm((96,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=96, out_features=384, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=384, out_features=96, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (1): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)\n",
       "          (norm): LayerNorm((96,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=96, out_features=384, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=384, out_features=96, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (2): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)\n",
       "          (norm): LayerNorm((96,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=96, out_features=384, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=384, out_features=96, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (1): ConvNeXtStage(\n",
       "      (downsample): Sequential(\n",
       "        (0): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)\n",
       "        (1): Conv2d(96, 192, kernel_size=(2, 2), stride=(2, 2))\n",
       "      )\n",
       "      (blocks): Sequential(\n",
       "        (0): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(192, 192, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=192)\n",
       "          (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=192, out_features=768, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=768, out_features=192, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (1): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(192, 192, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=192)\n",
       "          (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=192, out_features=768, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=768, out_features=192, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (2): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(192, 192, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=192)\n",
       "          (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=192, out_features=768, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=768, out_features=192, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (2): ConvNeXtStage(\n",
       "      (downsample): Sequential(\n",
       "        (0): LayerNorm2d((192,), eps=1e-06, elementwise_affine=True)\n",
       "        (1): Conv2d(192, 384, kernel_size=(2, 2), stride=(2, 2))\n",
       "      )\n",
       "      (blocks): Sequential(\n",
       "        (0): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)\n",
       "          (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=384, out_features=1536, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=1536, out_features=384, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (1): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)\n",
       "          (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=384, out_features=1536, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=1536, out_features=384, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (2): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)\n",
       "          (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=384, out_features=1536, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=1536, out_features=384, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (3): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)\n",
       "          (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=384, out_features=1536, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=1536, out_features=384, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (4): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)\n",
       "          (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=384, out_features=1536, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=1536, out_features=384, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (5): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)\n",
       "          (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=384, out_features=1536, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=1536, out_features=384, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (6): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)\n",
       "          (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=384, out_features=1536, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=1536, out_features=384, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (7): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)\n",
       "          (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=384, out_features=1536, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=1536, out_features=384, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (8): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(384, 384, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=384)\n",
       "          (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=384, out_features=1536, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=1536, out_features=384, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (3): ConvNeXtStage(\n",
       "      (downsample): Sequential(\n",
       "        (0): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)\n",
       "        (1): Conv2d(384, 768, kernel_size=(2, 2), stride=(2, 2))\n",
       "      )\n",
       "      (blocks): Sequential(\n",
       "        (0): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(768, 768, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=768)\n",
       "          (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (1): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(768, 768, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=768)\n",
       "          (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "        (2): ConvNeXtBlock(\n",
       "          (conv_dw): Conv2d(768, 768, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=768)\n",
       "          (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): Mlp(\n",
       "            (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (act): GELU()\n",
       "            (drop1): Dropout(p=0.0, inplace=False)\n",
       "            (norm): Identity()\n",
       "            (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (drop2): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (shortcut): Identity()\n",
       "          (drop_path): Identity()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (norm_pre): Identity()\n",
       "  (head): NormMlpClassifierHead(\n",
       "    (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Identity())\n",
       "    (norm): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)\n",
       "    (flatten): Flatten(start_dim=1, end_dim=-1)\n",
       "    (pre_logits): Identity()\n",
       "    (drop): Dropout(p=0.0, inplace=False)\n",
       "    (fc): Linear(in_features=768, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = timm.create_model('convnext_tiny.in12k_ft_in1k_384', pretrained=True)\n",
    "model.eval()\n",
    "\n",
    "\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8eb11239",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([ 0.8621,  1.0423,  1.2021,  0.7648,  0.8156, -0.6535,  1.0888,  0.4848,\n",
       "        -0.7411, -0.3887,  0.5110,  0.4774,  0.7412,  1.8590,  1.0018, -1.0614,\n",
       "         0.7098,  0.8344,  0.4168, -0.6064, -1.0316, -0.8434, -1.1538,  0.6827,\n",
       "         0.6306,  0.9008, -0.3564, -0.7322, -1.2152, -0.6058,  0.9057, -0.8514,\n",
       "         1.7696, -0.8580,  0.4737, -0.7703, -0.4138, -0.7393,  0.3107, -1.4632,\n",
       "        -1.1890, -0.6273, -0.7599,  0.7131, -0.4251, -0.9467,  0.3472,  1.3547,\n",
       "        -1.0745, -0.5855,  0.3670, -0.9164,  0.8657,  0.8809,  0.8392,  0.4621,\n",
       "        -0.8520, -0.4056,  0.2771, -0.2119,  1.1496, -0.5005,  0.3667,  0.4974,\n",
       "        -0.6973,  0.7200, -1.2957,  0.8765,  0.7978,  1.0101, -0.5039, -0.9359,\n",
       "         1.0101, -0.3791,  0.4655,  1.1019,  1.0709,  0.9069, -1.9776, -0.8656,\n",
       "         0.4127,  0.1739, -1.6337,  0.8025, -0.7882, -1.0101,  0.8860, -0.3181,\n",
       "        -0.6249, -0.7340,  1.1282, -0.5259, -0.7605, -0.5230, -0.1734, -0.2077],\n",
       "       requires_grad=True)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.model.stages[0].blocks[0].gamma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "87b7704c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(model.stages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f005ef93",
   "metadata": {},
   "outputs": [],
   "source": [
    "a = torch.ones(3, 384, 384)\n",
    "model(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbd11836",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Compose(\n",
       "    Resize(size=(384, 384), interpolation=bicubic, max_size=None, antialias=True)\n",
       "    CenterCrop(size=(384, 384))\n",
       "    MaybeToTensor()\n",
       "    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))\n",
       ")"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_config = timm.data.resolve_model_data_config(model)\n",
    "transforms = timm.data.create_transform(**data_config, is_training=False)\n",
    "\n",
    "transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0f96e32",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "====================================================================================================\n",
       "Layer (type:depth-idx)                             Output Shape              Param #\n",
       "====================================================================================================\n",
       "ConvNeXt                                           [15, 1000]                --\n",
       "├─Sequential: 1-1                                  [15, 96, 96, 96]          --\n",
       "│    └─Conv2d: 2-1                                 [15, 96, 96, 96]          4,704\n",
       "│    └─LayerNorm2d: 2-2                            [15, 96, 96, 96]          192\n",
       "├─Sequential: 1-2                                  [15, 768, 12, 12]         --\n",
       "│    └─ConvNeXtStage: 2-3                          [15, 96, 96, 96]          --\n",
       "│    │    └─Identity: 3-1                          [15, 96, 96, 96]          --\n",
       "│    │    └─Sequential: 3-2                        [15, 96, 96, 96]          237,888\n",
       "│    └─ConvNeXtStage: 2-4                          [15, 192, 48, 48]         --\n",
       "│    │    └─Sequential: 3-3                        [15, 192, 48, 48]         74,112\n",
       "│    │    └─Sequential: 3-4                        [15, 192, 48, 48]         918,144\n",
       "│    └─ConvNeXtStage: 2-5                          [15, 384, 24, 24]         --\n",
       "│    │    └─Sequential: 3-5                        [15, 384, 24, 24]         295,680\n",
       "│    │    └─Sequential: 3-6                        [15, 384, 24, 24]         10,817,280\n",
       "│    └─ConvNeXtStage: 2-6                          [15, 768, 12, 12]         --\n",
       "│    │    └─Sequential: 3-7                        [15, 768, 12, 12]         1,181,184\n",
       "│    │    └─Sequential: 3-8                        [15, 768, 12, 12]         14,289,408\n",
       "├─Identity: 1-3                                    [15, 768, 12, 12]         --\n",
       "├─NormMlpClassifierHead: 1-4                       [15, 1000]                --\n",
       "│    └─SelectAdaptivePool2d: 2-7                   [15, 768, 1, 1]           --\n",
       "│    │    └─AdaptiveAvgPool2d: 3-9                 [15, 768, 1, 1]           --\n",
       "│    │    └─Identity: 3-10                         [15, 768, 1, 1]           --\n",
       "│    └─LayerNorm2d: 2-8                            [15, 768, 1, 1]           1,536\n",
       "│    └─Flatten: 2-9                                [15, 768]                 --\n",
       "│    └─Identity: 2-10                              [15, 768]                 --\n",
       "│    └─Dropout: 2-11                               [15, 768]                 --\n",
       "│    └─Linear: 2-12                                [15, 1000]                769,000\n",
       "====================================================================================================\n",
       "Total params: 28,589,128\n",
       "Trainable params: 28,589,128\n",
       "Non-trainable params: 0\n",
       "Total mult-adds (Units.GIGABYTES): 13.43\n",
       "====================================================================================================\n",
       "Input size (MB): 26.54\n",
       "Forward/backward pass size (MB): 5786.39\n",
       "Params size (MB): 114.33\n",
       "Estimated Total Size (MB): 5927.26\n",
       "===================================================================================================="
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_size = 15\n",
    "summary(model, input_size = (batch_size, 3, 384, 384))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13c0b8eb",
   "metadata": {},
   "source": [
    "# RegNet-Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5d2c57ef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RegNet(\n",
       "  (stem): ConvNormAct(\n",
       "    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNormAct2d(\n",
       "      32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "      (drop): Identity()\n",
       "      (act): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (s1): RegStage(\n",
       "    (b1): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(32, 224, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(224, 224, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=2, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(224, 8, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(8, 224, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(224, 224, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): ConvNormAct(\n",
       "        (conv): Conv2d(32, 224, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "    (b2): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(224, 224, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(224, 56, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(56, 224, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(224, 224, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): Identity()\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "  )\n",
       "  (s2): RegStage(\n",
       "    (b1): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(224, 448, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(448, 448, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=4, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(448, 56, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(56, 448, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(448, 448, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): ConvNormAct(\n",
       "        (conv): Conv2d(224, 448, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "    (b2): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(448, 448, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(448, 448, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(448, 112, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(112, 448, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(448, 448, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): Identity()\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "    (b3): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(448, 448, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(448, 448, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(448, 112, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(112, 448, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(448, 448, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): Identity()\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "    (b4): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(448, 448, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(448, 448, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(448, 112, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(112, 448, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(448, 448, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): Identity()\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "    (b5): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(448, 448, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(448, 448, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(448, 112, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(112, 448, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(448, 448, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): Identity()\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "  )\n",
       "  (s3): RegStage(\n",
       "    (b1): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(448, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=8, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(896, 112, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(112, 896, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): ConvNormAct(\n",
       "        (conv): Conv2d(448, 896, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "    (b2): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(896, 224, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(224, 896, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): Identity()\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "    (b3): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(896, 224, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(224, 896, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): Identity()\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "    (b4): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(896, 224, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(224, 896, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): Identity()\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "    (b5): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(896, 224, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(224, 896, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): Identity()\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "    (b6): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(896, 224, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(224, 896, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): Identity()\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "    (b7): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(896, 224, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(224, 896, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): Identity()\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "    (b8): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(896, 224, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(224, 896, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): Identity()\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "    (b9): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(896, 224, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(224, 896, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): Identity()\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "    (b10): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(896, 224, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(224, 896, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): Identity()\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "    (b11): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(896, 224, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(224, 896, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(896, 896, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): Identity()\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "  )\n",
       "  (s4): RegStage(\n",
       "    (b1): Bottleneck(\n",
       "      (conv1): ConvNormAct(\n",
       "        (conv): Conv2d(896, 2240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          2240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (conv2): ConvNormAct(\n",
       "        (conv): Conv2d(2240, 2240, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=20, bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          2240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): ReLU(inplace=True)\n",
       "        )\n",
       "      )\n",
       "      (se): SEModule(\n",
       "        (fc1): Conv2d(2240, 224, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (bn): Identity()\n",
       "        (act): ReLU(inplace=True)\n",
       "        (fc2): Conv2d(224, 2240, kernel_size=(1, 1), stride=(1, 1))\n",
       "        (gate): Sigmoid()\n",
       "      )\n",
       "      (conv3): ConvNormAct(\n",
       "        (conv): Conv2d(2240, 2240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          2240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (act3): ReLU()\n",
       "      (downsample): ConvNormAct(\n",
       "        (conv): Conv2d(896, 2240, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (bn): BatchNormAct2d(\n",
       "          2240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "      )\n",
       "      (drop_path): Identity()\n",
       "    )\n",
       "  )\n",
       "  (final_conv): Identity()\n",
       "  (head): ClassifierHead(\n",
       "    (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
       "    (drop): Dropout(p=0.0, inplace=False)\n",
       "    (fc): Linear(in_features=2240, out_features=1000, bias=True)\n",
       "    (flatten): Identity()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = timm.create_model('regnety_120.sw_in12k_ft_in1k', pretrained=True)\n",
    "model.eval()\n",
    "\n",
    "\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c87cfd83",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(model.s1.children())[0].se.add_maxpool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "af804b79",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(list(model.s4.children()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "90db12f1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "===============================================================================================\n",
       "Layer (type:depth-idx)                        Output Shape              Param #\n",
       "===============================================================================================\n",
       "RegNet                                        [1, 1000]                 --\n",
       "├─ConvNormAct: 1-1                            [1, 32, 144, 144]         --\n",
       "│    └─Conv2d: 2-1                            [1, 32, 144, 144]         864\n",
       "│    └─BatchNormAct2d: 2-2                    [1, 32, 144, 144]         64\n",
       "│    │    └─Identity: 3-1                     [1, 32, 144, 144]         --\n",
       "│    │    └─ReLU: 3-2                         [1, 32, 144, 144]         --\n",
       "├─RegStage: 1-2                               [1, 224, 72, 72]          --\n",
       "│    └─Bottleneck: 2-3                        [1, 224, 72, 72]          --\n",
       "│    │    └─ConvNormAct: 3-3                  [1, 224, 144, 144]        7,616\n",
       "│    │    └─ConvNormAct: 3-4                  [1, 224, 72, 72]          226,240\n",
       "│    │    └─SEModule: 3-5                     [1, 224, 72, 72]          3,816\n",
       "│    │    └─ConvNormAct: 3-6                  [1, 224, 72, 72]          50,624\n",
       "│    │    └─Identity: 3-7                     [1, 224, 72, 72]          --\n",
       "│    │    └─ConvNormAct: 3-8                  [1, 224, 72, 72]          7,616\n",
       "│    │    └─ReLU: 3-9                         [1, 224, 72, 72]          --\n",
       "│    └─Bottleneck: 2-4                        [1, 224, 72, 72]          --\n",
       "│    │    └─ConvNormAct: 3-10                 [1, 224, 72, 72]          50,624\n",
       "│    │    └─ConvNormAct: 3-11                 [1, 224, 72, 72]          226,240\n",
       "│    │    └─SEModule: 3-12                    [1, 224, 72, 72]          25,368\n",
       "│    │    └─ConvNormAct: 3-13                 [1, 224, 72, 72]          50,624\n",
       "│    │    └─Identity: 3-14                    [1, 224, 72, 72]          --\n",
       "│    │    └─Identity: 3-15                    [1, 224, 72, 72]          --\n",
       "│    │    └─ReLU: 3-16                        [1, 224, 72, 72]          --\n",
       "├─RegStage: 1-3                               [1, 448, 36, 36]          --\n",
       "│    └─Bottleneck: 2-5                        [1, 448, 36, 36]          --\n",
       "│    │    └─ConvNormAct: 3-17                 [1, 448, 72, 72]          101,248\n",
       "│    │    └─ConvNormAct: 3-18                 [1, 448, 36, 36]          452,480\n",
       "│    │    └─SEModule: 3-19                    [1, 448, 36, 36]          50,680\n",
       "│    │    └─ConvNormAct: 3-20                 [1, 448, 36, 36]          201,600\n",
       "│    │    └─Identity: 3-21                    [1, 448, 36, 36]          --\n",
       "│    │    └─ConvNormAct: 3-22                 [1, 448, 36, 36]          101,248\n",
       "│    │    └─ReLU: 3-23                        [1, 448, 36, 36]          --\n",
       "│    └─Bottleneck: 2-6                        [1, 448, 36, 36]          --\n",
       "│    │    └─ConvNormAct: 3-24                 [1, 448, 36, 36]          201,600\n",
       "│    │    └─ConvNormAct: 3-25                 [1, 448, 36, 36]          452,480\n",
       "│    │    └─SEModule: 3-26                    [1, 448, 36, 36]          100,912\n",
       "│    │    └─ConvNormAct: 3-27                 [1, 448, 36, 36]          201,600\n",
       "│    │    └─Identity: 3-28                    [1, 448, 36, 36]          --\n",
       "│    │    └─Identity: 3-29                    [1, 448, 36, 36]          --\n",
       "│    │    └─ReLU: 3-30                        [1, 448, 36, 36]          --\n",
       "│    └─Bottleneck: 2-7                        [1, 448, 36, 36]          --\n",
       "│    │    └─ConvNormAct: 3-31                 [1, 448, 36, 36]          201,600\n",
       "│    │    └─ConvNormAct: 3-32                 [1, 448, 36, 36]          452,480\n",
       "│    │    └─SEModule: 3-33                    [1, 448, 36, 36]          100,912\n",
       "│    │    └─ConvNormAct: 3-34                 [1, 448, 36, 36]          201,600\n",
       "│    │    └─Identity: 3-35                    [1, 448, 36, 36]          --\n",
       "│    │    └─Identity: 3-36                    [1, 448, 36, 36]          --\n",
       "│    │    └─ReLU: 3-37                        [1, 448, 36, 36]          --\n",
       "│    └─Bottleneck: 2-8                        [1, 448, 36, 36]          --\n",
       "│    │    └─ConvNormAct: 3-38                 [1, 448, 36, 36]          201,600\n",
       "│    │    └─ConvNormAct: 3-39                 [1, 448, 36, 36]          452,480\n",
       "│    │    └─SEModule: 3-40                    [1, 448, 36, 36]          100,912\n",
       "│    │    └─ConvNormAct: 3-41                 [1, 448, 36, 36]          201,600\n",
       "│    │    └─Identity: 3-42                    [1, 448, 36, 36]          --\n",
       "│    │    └─Identity: 3-43                    [1, 448, 36, 36]          --\n",
       "│    │    └─ReLU: 3-44                        [1, 448, 36, 36]          --\n",
       "│    └─Bottleneck: 2-9                        [1, 448, 36, 36]          --\n",
       "│    │    └─ConvNormAct: 3-45                 [1, 448, 36, 36]          201,600\n",
       "│    │    └─ConvNormAct: 3-46                 [1, 448, 36, 36]          452,480\n",
       "│    │    └─SEModule: 3-47                    [1, 448, 36, 36]          100,912\n",
       "│    │    └─ConvNormAct: 3-48                 [1, 448, 36, 36]          201,600\n",
       "│    │    └─Identity: 3-49                    [1, 448, 36, 36]          --\n",
       "│    │    └─Identity: 3-50                    [1, 448, 36, 36]          --\n",
       "│    │    └─ReLU: 3-51                        [1, 448, 36, 36]          --\n",
       "├─RegStage: 1-4                               [1, 896, 18, 18]          --\n",
       "│    └─Bottleneck: 2-10                       [1, 896, 18, 18]          --\n",
       "│    │    └─ConvNormAct: 3-52                 [1, 896, 36, 36]          403,200\n",
       "│    │    └─ConvNormAct: 3-53                 [1, 896, 18, 18]          904,960\n",
       "│    │    └─SEModule: 3-54                    [1, 896, 18, 18]          201,712\n",
       "│    │    └─ConvNormAct: 3-55                 [1, 896, 18, 18]          804,608\n",
       "│    │    └─Identity: 3-56                    [1, 896, 18, 18]          --\n",
       "│    │    └─ConvNormAct: 3-57                 [1, 896, 18, 18]          403,200\n",
       "│    │    └─ReLU: 3-58                        [1, 896, 18, 18]          --\n",
       "│    └─Bottleneck: 2-11                       [1, 896, 18, 18]          --\n",
       "│    │    └─ConvNormAct: 3-59                 [1, 896, 18, 18]          804,608\n",
       "│    │    └─ConvNormAct: 3-60                 [1, 896, 18, 18]          904,960\n",
       "│    │    └─SEModule: 3-61                    [1, 896, 18, 18]          402,528\n",
       "│    │    └─ConvNormAct: 3-62                 [1, 896, 18, 18]          804,608\n",
       "│    │    └─Identity: 3-63                    [1, 896, 18, 18]          --\n",
       "│    │    └─Identity: 3-64                    [1, 896, 18, 18]          --\n",
       "│    │    └─ReLU: 3-65                        [1, 896, 18, 18]          --\n",
       "│    └─Bottleneck: 2-12                       [1, 896, 18, 18]          --\n",
       "│    │    └─ConvNormAct: 3-66                 [1, 896, 18, 18]          804,608\n",
       "│    │    └─ConvNormAct: 3-67                 [1, 896, 18, 18]          904,960\n",
       "│    │    └─SEModule: 3-68                    [1, 896, 18, 18]          402,528\n",
       "│    │    └─ConvNormAct: 3-69                 [1, 896, 18, 18]          804,608\n",
       "│    │    └─Identity: 3-70                    [1, 896, 18, 18]          --\n",
       "│    │    └─Identity: 3-71                    [1, 896, 18, 18]          --\n",
       "│    │    └─ReLU: 3-72                        [1, 896, 18, 18]          --\n",
       "│    └─Bottleneck: 2-13                       [1, 896, 18, 18]          --\n",
       "│    │    └─ConvNormAct: 3-73                 [1, 896, 18, 18]          804,608\n",
       "│    │    └─ConvNormAct: 3-74                 [1, 896, 18, 18]          904,960\n",
       "│    │    └─SEModule: 3-75                    [1, 896, 18, 18]          402,528\n",
       "│    │    └─ConvNormAct: 3-76                 [1, 896, 18, 18]          804,608\n",
       "│    │    └─Identity: 3-77                    [1, 896, 18, 18]          --\n",
       "│    │    └─Identity: 3-78                    [1, 896, 18, 18]          --\n",
       "│    │    └─ReLU: 3-79                        [1, 896, 18, 18]          --\n",
       "│    └─Bottleneck: 2-14                       [1, 896, 18, 18]          --\n",
       "│    │    └─ConvNormAct: 3-80                 [1, 896, 18, 18]          804,608\n",
       "│    │    └─ConvNormAct: 3-81                 [1, 896, 18, 18]          904,960\n",
       "│    │    └─SEModule: 3-82                    [1, 896, 18, 18]          402,528\n",
       "│    │    └─ConvNormAct: 3-83                 [1, 896, 18, 18]          804,608\n",
       "│    │    └─Identity: 3-84                    [1, 896, 18, 18]          --\n",
       "│    │    └─Identity: 3-85                    [1, 896, 18, 18]          --\n",
       "│    │    └─ReLU: 3-86                        [1, 896, 18, 18]          --\n",
       "│    └─Bottleneck: 2-15                       [1, 896, 18, 18]          --\n",
       "│    │    └─ConvNormAct: 3-87                 [1, 896, 18, 18]          804,608\n",
       "│    │    └─ConvNormAct: 3-88                 [1, 896, 18, 18]          904,960\n",
       "│    │    └─SEModule: 3-89                    [1, 896, 18, 18]          402,528\n",
       "│    │    └─ConvNormAct: 3-90                 [1, 896, 18, 18]          804,608\n",
       "│    │    └─Identity: 3-91                    [1, 896, 18, 18]          --\n",
       "│    │    └─Identity: 3-92                    [1, 896, 18, 18]          --\n",
       "│    │    └─ReLU: 3-93                        [1, 896, 18, 18]          --\n",
       "│    └─Bottleneck: 2-16                       [1, 896, 18, 18]          --\n",
       "│    │    └─ConvNormAct: 3-94                 [1, 896, 18, 18]          804,608\n",
       "│    │    └─ConvNormAct: 3-95                 [1, 896, 18, 18]          904,960\n",
       "│    │    └─SEModule: 3-96                    [1, 896, 18, 18]          402,528\n",
       "│    │    └─ConvNormAct: 3-97                 [1, 896, 18, 18]          804,608\n",
       "│    │    └─Identity: 3-98                    [1, 896, 18, 18]          --\n",
       "│    │    └─Identity: 3-99                    [1, 896, 18, 18]          --\n",
       "│    │    └─ReLU: 3-100                       [1, 896, 18, 18]          --\n",
       "│    └─Bottleneck: 2-17                       [1, 896, 18, 18]          --\n",
       "│    │    └─ConvNormAct: 3-101                [1, 896, 18, 18]          804,608\n",
       "│    │    └─ConvNormAct: 3-102                [1, 896, 18, 18]          904,960\n",
       "│    │    └─SEModule: 3-103                   [1, 896, 18, 18]          402,528\n",
       "│    │    └─ConvNormAct: 3-104                [1, 896, 18, 18]          804,608\n",
       "│    │    └─Identity: 3-105                   [1, 896, 18, 18]          --\n",
       "│    │    └─Identity: 3-106                   [1, 896, 18, 18]          --\n",
       "│    │    └─ReLU: 3-107                       [1, 896, 18, 18]          --\n",
       "│    └─Bottleneck: 2-18                       [1, 896, 18, 18]          --\n",
       "│    │    └─ConvNormAct: 3-108                [1, 896, 18, 18]          804,608\n",
       "│    │    └─ConvNormAct: 3-109                [1, 896, 18, 18]          904,960\n",
       "│    │    └─SEModule: 3-110                   [1, 896, 18, 18]          402,528\n",
       "│    │    └─ConvNormAct: 3-111                [1, 896, 18, 18]          804,608\n",
       "│    │    └─Identity: 3-112                   [1, 896, 18, 18]          --\n",
       "│    │    └─Identity: 3-113                   [1, 896, 18, 18]          --\n",
       "│    │    └─ReLU: 3-114                       [1, 896, 18, 18]          --\n",
       "│    └─Bottleneck: 2-19                       [1, 896, 18, 18]          --\n",
       "│    │    └─ConvNormAct: 3-115                [1, 896, 18, 18]          804,608\n",
       "│    │    └─ConvNormAct: 3-116                [1, 896, 18, 18]          904,960\n",
       "│    │    └─SEModule: 3-117                   [1, 896, 18, 18]          402,528\n",
       "│    │    └─ConvNormAct: 3-118                [1, 896, 18, 18]          804,608\n",
       "│    │    └─Identity: 3-119                   [1, 896, 18, 18]          --\n",
       "│    │    └─Identity: 3-120                   [1, 896, 18, 18]          --\n",
       "│    │    └─ReLU: 3-121                       [1, 896, 18, 18]          --\n",
       "│    └─Bottleneck: 2-20                       [1, 896, 18, 18]          --\n",
       "│    │    └─ConvNormAct: 3-122                [1, 896, 18, 18]          804,608\n",
       "│    │    └─ConvNormAct: 3-123                [1, 896, 18, 18]          904,960\n",
       "│    │    └─SEModule: 3-124                   [1, 896, 18, 18]          402,528\n",
       "│    │    └─ConvNormAct: 3-125                [1, 896, 18, 18]          804,608\n",
       "│    │    └─Identity: 3-126                   [1, 896, 18, 18]          --\n",
       "│    │    └─Identity: 3-127                   [1, 896, 18, 18]          --\n",
       "│    │    └─ReLU: 3-128                       [1, 896, 18, 18]          --\n",
       "├─RegStage: 1-5                               [1, 2240, 9, 9]           --\n",
       "│    └─Bottleneck: 2-21                       [1, 2240, 9, 9]           --\n",
       "│    │    └─ConvNormAct: 3-129                [1, 2240, 18, 18]         2,011,520\n",
       "│    │    └─ConvNormAct: 3-130                [1, 2240, 9, 9]           2,262,400\n",
       "│    │    └─SEModule: 3-131                   [1, 2240, 9, 9]           1,005,984\n",
       "│    │    └─ConvNormAct: 3-132                [1, 2240, 9, 9]           5,022,080\n",
       "│    │    └─Identity: 3-133                   [1, 2240, 9, 9]           --\n",
       "│    │    └─ConvNormAct: 3-134                [1, 2240, 9, 9]           2,011,520\n",
       "│    │    └─ReLU: 3-135                       [1, 2240, 9, 9]           --\n",
       "├─Identity: 1-6                               [1, 2240, 9, 9]           --\n",
       "├─ClassifierHead: 1-7                         [1, 1000]                 --\n",
       "│    └─SelectAdaptivePool2d: 2-22             [1, 2240]                 --\n",
       "│    │    └─AdaptiveAvgPool2d: 3-136          [1, 2240, 1, 1]           --\n",
       "│    │    └─Flatten: 3-137                    [1, 2240]                 --\n",
       "│    └─Dropout: 2-23                          [1, 2240]                 --\n",
       "│    └─Linear: 2-24                           [1, 1000]                 2,241,000\n",
       "│    └─Identity: 2-25                         [1, 1000]                 --\n",
       "===============================================================================================\n",
       "Total params: 51,822,544\n",
       "Trainable params: 51,822,544\n",
       "Non-trainable params: 0\n",
       "Total mult-adds (Units.GIGABYTES): 19.98\n",
       "===============================================================================================\n",
       "Input size (MB): 1.00\n",
       "Forward/backward pass size (MB): 282.70\n",
       "Params size (MB): 206.90\n",
       "Estimated Total Size (MB): 490.60\n",
       "==============================================================================================="
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "summary(model, input_size = (1, 3, 288, 288))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24efd66c",
   "metadata": {},
   "source": [
    "# EfficientNet-V2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e6514a4c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "EfficientNet(\n",
       "  (conv_stem): Conv2dSame(3, 24, kernel_size=(3, 3), stride=(2, 2), bias=False)\n",
       "  (bn1): BatchNormAct2d(\n",
       "    24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "    (drop): Identity()\n",
       "    (act): SiLU(inplace=True)\n",
       "  )\n",
       "  (blocks): Sequential(\n",
       "    (0): Sequential(\n",
       "      (0): ConvBnAct(\n",
       "        (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (1): ConvBnAct(\n",
       "        (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "    )\n",
       "    (1): Sequential(\n",
       "      (0): EdgeResidual(\n",
       "        (conv_exp): Conv2dSame(24, 96, kernel_size=(3, 3), stride=(2, 2), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): Identity()\n",
       "        (conv_pwl): Conv2d(96, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (1): EdgeResidual(\n",
       "        (conv_exp): Conv2d(48, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): Identity()\n",
       "        (conv_pwl): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (2): EdgeResidual(\n",
       "        (conv_exp): Conv2d(48, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): Identity()\n",
       "        (conv_pwl): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (3): EdgeResidual(\n",
       "        (conv_exp): Conv2d(48, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): Identity()\n",
       "        (conv_pwl): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "    )\n",
       "    (2): Sequential(\n",
       "      (0): EdgeResidual(\n",
       "        (conv_exp): Conv2dSame(48, 192, kernel_size=(3, 3), stride=(2, 2), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): Identity()\n",
       "        (conv_pwl): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (1): EdgeResidual(\n",
       "        (conv_exp): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): Identity()\n",
       "        (conv_pwl): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (2): EdgeResidual(\n",
       "        (conv_exp): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): Identity()\n",
       "        (conv_pwl): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (3): EdgeResidual(\n",
       "        (conv_exp): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): Identity()\n",
       "        (conv_pwl): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "    )\n",
       "    (3): Sequential(\n",
       "      (0): InvertedResidual(\n",
       "        (conv_pw): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2dSame(256, 256, kernel_size=(3, 3), stride=(2, 2), groups=256, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(256, 16, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (1): InvertedResidual(\n",
       "        (conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (2): InvertedResidual(\n",
       "        (conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (3): InvertedResidual(\n",
       "        (conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (4): InvertedResidual(\n",
       "        (conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (5): InvertedResidual(\n",
       "        (conv_pw): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "    )\n",
       "    (4): Sequential(\n",
       "      (0): InvertedResidual(\n",
       "        (conv_pw): Conv2d(128, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          768, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          768, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(768, 32, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(32, 768, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(768, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (1): InvertedResidual(\n",
       "        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (2): InvertedResidual(\n",
       "        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (3): InvertedResidual(\n",
       "        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (4): InvertedResidual(\n",
       "        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (5): InvertedResidual(\n",
       "        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (6): InvertedResidual(\n",
       "        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (7): InvertedResidual(\n",
       "        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (8): InvertedResidual(\n",
       "        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "    )\n",
       "    (5): Sequential(\n",
       "      (0): InvertedResidual(\n",
       "        (conv_pw): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2dSame(960, 960, kernel_size=(3, 3), stride=(2, 2), groups=960, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          960, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(960, 40, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(40, 960, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(960, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (1): InvertedResidual(\n",
       "        (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (2): InvertedResidual(\n",
       "        (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (3): InvertedResidual(\n",
       "        (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (4): InvertedResidual(\n",
       "        (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (5): InvertedResidual(\n",
       "        (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (6): InvertedResidual(\n",
       "        (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (7): InvertedResidual(\n",
       "        (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (8): InvertedResidual(\n",
       "        (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (9): InvertedResidual(\n",
       "        (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (10): InvertedResidual(\n",
       "        (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (11): InvertedResidual(\n",
       "        (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (12): InvertedResidual(\n",
       "        (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (13): InvertedResidual(\n",
       "        (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "      (14): InvertedResidual(\n",
       "        (conv_pw): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (conv_dw): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1536, bias=False)\n",
       "        (bn2): BatchNormAct2d(\n",
       "          1536, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): SiLU(inplace=True)\n",
       "        )\n",
       "        (aa): Identity()\n",
       "        (se): SqueezeExcite(\n",
       "          (conv_reduce): Conv2d(1536, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (act1): SiLU(inplace=True)\n",
       "          (conv_expand): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n",
       "          (gate): Sigmoid()\n",
       "        )\n",
       "        (conv_pwl): Conv2d(1536, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNormAct2d(\n",
       "          256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "          (drop): Identity()\n",
       "          (act): Identity()\n",
       "        )\n",
       "        (drop_path): Identity()\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (conv_head): Conv2d(256, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "  (bn2): BatchNormAct2d(\n",
       "    1280, eps=0.001, momentum=0.1, affine=True, track_running_stats=True\n",
       "    (drop): Identity()\n",
       "    (act): SiLU(inplace=True)\n",
       "  )\n",
       "  (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
       "  (classifier): Linear(in_features=1280, out_features=1000, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = timm.create_model('tf_efficientnetv2_s.in21k_ft_in1k', pretrained=True)\n",
    "data_config = timm.data.resolve_model_data_config(model)\n",
    "transforms = timm.data.create_transform(**data_config, is_training=False)\n",
    "model.eval()\n",
    "\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "52f2cf2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "for stage in model.blocks[0]:\n",
    "    print(stage.has_skip)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0e5d2866",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Compose(\n",
       "    Resize(size=300, interpolation=bicubic, max_size=None, antialias=True)\n",
       "    CenterCrop(size=(300, 300))\n",
       "    MaybeToTensor()\n",
       "    Normalize(mean=tensor([0.5000, 0.5000, 0.5000]), std=tensor([0.5000, 0.5000, 0.5000]))\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dd8e8ccf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "===============================================================================================\n",
       "Layer (type:depth-idx)                        Output Shape              Param #\n",
       "===============================================================================================\n",
       "EfficientNet                                  [1, 1000]                 --\n",
       "├─Conv2dSame: 1-1                             [1, 24, 150, 150]         648\n",
       "├─BatchNormAct2d: 1-2                         [1, 24, 150, 150]         48\n",
       "│    └─Identity: 2-1                          [1, 24, 150, 150]         --\n",
       "│    └─SiLU: 2-2                              [1, 24, 150, 150]         --\n",
       "├─Sequential: 1-3                             [1, 256, 10, 10]          --\n",
       "│    └─Sequential: 2-3                        [1, 24, 150, 150]         --\n",
       "│    │    └─ConvBnAct: 3-1                    [1, 24, 150, 150]         5,232\n",
       "│    │    └─ConvBnAct: 3-2                    [1, 24, 150, 150]         5,232\n",
       "│    └─Sequential: 2-4                        [1, 48, 75, 75]           --\n",
       "│    │    └─EdgeResidual: 3-3                 [1, 48, 75, 75]           25,632\n",
       "│    │    └─EdgeResidual: 3-4                 [1, 48, 75, 75]           92,640\n",
       "│    │    └─EdgeResidual: 3-5                 [1, 48, 75, 75]           92,640\n",
       "│    │    └─EdgeResidual: 3-6                 [1, 48, 75, 75]           92,640\n",
       "│    └─Sequential: 2-5                        [1, 64, 38, 38]           --\n",
       "│    │    └─EdgeResidual: 3-7                 [1, 64, 38, 38]           95,744\n",
       "│    │    └─EdgeResidual: 3-8                 [1, 64, 38, 38]           164,480\n",
       "│    │    └─EdgeResidual: 3-9                 [1, 64, 38, 38]           164,480\n",
       "│    │    └─EdgeResidual: 3-10                [1, 64, 38, 38]           164,480\n",
       "│    └─Sequential: 2-6                        [1, 128, 19, 19]          --\n",
       "│    │    └─InvertedResidual: 3-11            [1, 128, 19, 19]          61,200\n",
       "│    │    └─InvertedResidual: 3-12            [1, 128, 19, 19]          171,296\n",
       "│    │    └─InvertedResidual: 3-13            [1, 128, 19, 19]          171,296\n",
       "│    │    └─InvertedResidual: 3-14            [1, 128, 19, 19]          171,296\n",
       "│    │    └─InvertedResidual: 3-15            [1, 128, 19, 19]          171,296\n",
       "│    │    └─InvertedResidual: 3-16            [1, 128, 19, 19]          171,296\n",
       "│    └─Sequential: 2-7                        [1, 160, 19, 19]          --\n",
       "│    │    └─InvertedResidual: 3-17            [1, 160, 19, 19]          281,440\n",
       "│    │    └─InvertedResidual: 3-18            [1, 160, 19, 19]          397,800\n",
       "│    │    └─InvertedResidual: 3-19            [1, 160, 19, 19]          397,800\n",
       "│    │    └─InvertedResidual: 3-20            [1, 160, 19, 19]          397,800\n",
       "│    │    └─InvertedResidual: 3-21            [1, 160, 19, 19]          397,800\n",
       "│    │    └─InvertedResidual: 3-22            [1, 160, 19, 19]          397,800\n",
       "│    │    └─InvertedResidual: 3-23            [1, 160, 19, 19]          397,800\n",
       "│    │    └─InvertedResidual: 3-24            [1, 160, 19, 19]          397,800\n",
       "│    │    └─InvertedResidual: 3-25            [1, 160, 19, 19]          397,800\n",
       "│    └─Sequential: 2-8                        [1, 256, 10, 10]          --\n",
       "│    │    └─InvertedResidual: 3-26            [1, 256, 10, 10]          490,152\n",
       "│    │    └─InvertedResidual: 3-27            [1, 256, 10, 10]          1,005,120\n",
       "│    │    └─InvertedResidual: 3-28            [1, 256, 10, 10]          1,005,120\n",
       "│    │    └─InvertedResidual: 3-29            [1, 256, 10, 10]          1,005,120\n",
       "│    │    └─InvertedResidual: 3-30            [1, 256, 10, 10]          1,005,120\n",
       "│    │    └─InvertedResidual: 3-31            [1, 256, 10, 10]          1,005,120\n",
       "│    │    └─InvertedResidual: 3-32            [1, 256, 10, 10]          1,005,120\n",
       "│    │    └─InvertedResidual: 3-33            [1, 256, 10, 10]          1,005,120\n",
       "│    │    └─InvertedResidual: 3-34            [1, 256, 10, 10]          1,005,120\n",
       "│    │    └─InvertedResidual: 3-35            [1, 256, 10, 10]          1,005,120\n",
       "│    │    └─InvertedResidual: 3-36            [1, 256, 10, 10]          1,005,120\n",
       "│    │    └─InvertedResidual: 3-37            [1, 256, 10, 10]          1,005,120\n",
       "│    │    └─InvertedResidual: 3-38            [1, 256, 10, 10]          1,005,120\n",
       "│    │    └─InvertedResidual: 3-39            [1, 256, 10, 10]          1,005,120\n",
       "│    │    └─InvertedResidual: 3-40            [1, 256, 10, 10]          1,005,120\n",
       "├─Conv2d: 1-4                                 [1, 1280, 10, 10]         327,680\n",
       "├─BatchNormAct2d: 1-5                         [1, 1280, 10, 10]         2,560\n",
       "│    └─Identity: 2-9                          [1, 1280, 10, 10]         --\n",
       "│    └─SiLU: 2-10                             [1, 1280, 10, 10]         --\n",
       "├─SelectAdaptivePool2d: 1-6                   [1, 1280]                 --\n",
       "│    └─AdaptiveAvgPool2d: 2-11                [1, 1280, 1, 1]           --\n",
       "│    └─Flatten: 2-12                          [1, 1280]                 --\n",
       "├─Linear: 1-7                                 [1, 1000]                 1,281,000\n",
       "===============================================================================================\n",
       "Total params: 21,458,488\n",
       "Trainable params: 21,458,488\n",
       "Non-trainable params: 0\n",
       "Total mult-adds (Units.GIGABYTES): 5.31\n",
       "===============================================================================================\n",
       "Input size (MB): 1.08\n",
       "Forward/backward pass size (MB): 181.88\n",
       "Params size (MB): 85.22\n",
       "Estimated Total Size (MB): 268.17\n",
       "==============================================================================================="
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "summary(model, input_size = (1, 3, 300, 300))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "AD",
   "language": "python",
   "name": "ad"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
