{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f14e6906-e97f-413d-85e7-f34073aeab0a",
   "metadata": {},
   "source": [
    "Here we demo how to construct an LRM network manually, and load our lrm3, lrm2, and lrm1 alexnet models. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "01338bf3-aed9-4f7b-897b-09765103168d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import models\n",
    "from models import LRMNet\n",
    "from torchvision.models import alexnet"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5c1ae1e-32d3-4f92-8436-9e687010eb60",
   "metadata": {},
   "source": [
    "## manually construct alexnet_lrm network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1e88ef21-8f9b-4158-bc80-ff8403e89497",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LRMNet(\n",
       "  (backbone): AlexNet(\n",
       "    (features): Sequential(\n",
       "      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n",
       "      (1): ReLU()\n",
       "      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
       "      (4): ReLU(inplace=True)\n",
       "      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (7): ReLU(inplace=True)\n",
       "      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (9): ReLU()\n",
       "      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (11): ReLU(inplace=True)\n",
       "      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    )\n",
       "    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n",
       "    (classifier): Sequential(\n",
       "      (0): Dropout(p=0.5, inplace=False)\n",
       "      (1): Linear(in_features=9216, out_features=4096, bias=True)\n",
       "      (2): ReLU(inplace=True)\n",
       "      (3): Dropout(p=0.5, inplace=False)\n",
       "      (4): Linear(in_features=4096, out_features=4096, bias=True)\n",
       "      (5): ReLU(inplace=True)\n",
       "      (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (lrm): Sequential(\n",
       "    (features8_modulation): LongRangeModulation(\n",
       "      (from_classifier_6_to_features_8): ModBlock(\n",
       "        (rescale): NormSquashResize(\n",
       "          (norm): ChannelNorm((1000,), eps=1e-05, elementwise_affine=True)\n",
       "          (squash): FeedbackScale(mode='tanh')\n",
       "          (interp): AddSpatialDimension()\n",
       "        )\n",
       "        (modulation): Conv2d(1000, 256, kernel_size=(1, 1), stride=(1, 1))\n",
       "      )\n",
       "      (pre_mod_output): Identity()\n",
       "      (total_mod): Identity()\n",
       "      (post_mod_output): Identity()\n",
       "    )\n",
       "    (features0_modulation): LongRangeModulation(\n",
       "      (from_features_9_to_features_0): ModBlock(\n",
       "        (rescale): NormSquashResize(\n",
       "          (norm): AdaptiveFullstackNorm((256, 55, 55), eps=1e-05, elementwise_affine=True)\n",
       "          (squash): FeedbackScale(mode='tanh')\n",
       "          (interp): AdaptiveUpsample(upsample_mode='UpsampleBilinear')\n",
       "        )\n",
       "        (modulation): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "      )\n",
       "      (pre_mod_output): Identity()\n",
       "      (total_mod): Identity()\n",
       "      (post_mod_output): Identity()\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# load backbone\n",
    "backbone = alexnet()\n",
    "\n",
    "# Feedback is specified as a tuple with a 'target_layer' and list of [source_layers]\n",
    "# - features.8 recieves feedback from classifier.6\n",
    "# - features.0 receives feedback from features.9\n",
    "\n",
    "mod_connections = [ \n",
    "    ('features.8', ['classifier.6']),\n",
    "    ('features.0', ['features.9']),\n",
    "]\n",
    "\n",
    "# create the model with default number of forward passes (time_steps) and expected img_size\n",
    "# - default number of forward passes can be overridden in forward pass\n",
    "# - actual input img_size can be any size and LRMNet will adapt feedback size automatically\n",
    "model = LRMNet(backbone, mod_connections, time_steps=2, img_size=224)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "471e81e2-54f6-4e20-afc9-3d462ca6a84b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 1000])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand(5,3,224,224)\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    out = model(x)\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d69ff63-5817-46b6-a144-ee20beafdc50",
   "metadata": {},
   "source": [
    "## lrm3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "73afd361-778d-4e5c-a6f6-652cf1fc812b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LRMNet(\n",
       "  (backbone): AlexNet(\n",
       "    (features): Sequential(\n",
       "      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n",
       "      (1): ReLU()\n",
       "      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
       "      (4): ReLU()\n",
       "      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (7): ReLU()\n",
       "      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (9): ReLU()\n",
       "      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (11): ReLU()\n",
       "      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    )\n",
       "    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n",
       "    (classifier): Sequential(\n",
       "      (0): Dropout(p=0.5, inplace=False)\n",
       "      (1): Linear(in_features=9216, out_features=4096, bias=True)\n",
       "      (2): ReLU(inplace=True)\n",
       "      (3): Dropout(p=0.5, inplace=False)\n",
       "      (4): Linear(in_features=4096, out_features=4096, bias=True)\n",
       "      (5): ReLU(inplace=True)\n",
       "      (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (lrm): Sequential(\n",
       "    (features0_modulation): LongRangeModulation(\n",
       "      (from_features_9_to_features_0): ModBlock(\n",
       "        (rescale): NormSquashResize(\n",
       "          (norm): AdaptiveFullstackNorm((256, 55, 55), eps=1e-05, elementwise_affine=True)\n",
       "          (squash): FeedbackScale(mode='tanh')\n",
       "          (interp): AdaptiveUpsample(upsample_mode='UpsampleBilinear')\n",
       "        )\n",
       "        (modulation): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "      )\n",
       "      (pre_mod_output): Identity()\n",
       "      (total_mod): Identity()\n",
       "      (post_mod_output): Identity()\n",
       "    )\n",
       "    (features3_modulation): LongRangeModulation(\n",
       "      (from_features_12_to_features_3): ModBlock(\n",
       "        (rescale): NormSquashResize(\n",
       "          (norm): AdaptiveFullstackNorm((256, 27, 27), eps=1e-05, elementwise_affine=True)\n",
       "          (squash): FeedbackScale(mode='tanh')\n",
       "          (interp): AdaptiveUpsample(upsample_mode='UpsampleBilinear')\n",
       "        )\n",
       "        (modulation): Conv2d(256, 192, kernel_size=(1, 1), stride=(1, 1))\n",
       "      )\n",
       "      (pre_mod_output): Identity()\n",
       "      (total_mod): Identity()\n",
       "      (post_mod_output): Identity()\n",
       "    )\n",
       "    (features6_modulation): LongRangeModulation(\n",
       "      (from_classifier_2_to_features_6): ModBlock(\n",
       "        (rescale): NormSquashResize(\n",
       "          (norm): ChannelNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
       "          (squash): FeedbackScale(mode='tanh')\n",
       "          (interp): AddSpatialDimension()\n",
       "        )\n",
       "        (modulation): Conv2d(4096, 384, kernel_size=(1, 1), stride=(1, 1))\n",
       "      )\n",
       "      (pre_mod_output): Identity()\n",
       "      (total_mod): Identity()\n",
       "      (post_mod_output): Identity()\n",
       "    )\n",
       "    (features8_modulation): LongRangeModulation(\n",
       "      (from_classifier_6_to_features_8): ModBlock(\n",
       "        (rescale): NormSquashResize(\n",
       "          (norm): ChannelNorm((1000,), eps=1e-05, elementwise_affine=True)\n",
       "          (squash): FeedbackScale(mode='tanh')\n",
       "          (interp): AddSpatialDimension()\n",
       "        )\n",
       "        (modulation): Conv2d(1000, 256, kernel_size=(1, 1), stride=(1, 1))\n",
       "      )\n",
       "      (pre_mod_output): Identity()\n",
       "      (total_mod): Identity()\n",
       "      (post_mod_output): Identity()\n",
       "    )\n",
       "    (features10_modulation): LongRangeModulation(\n",
       "      (from_classifier_6_to_features_10): ModBlock(\n",
       "        (rescale): NormSquashResize(\n",
       "          (norm): ChannelNorm((1000,), eps=1e-05, elementwise_affine=True)\n",
       "          (squash): FeedbackScale(mode='tanh')\n",
       "          (interp): AddSpatialDimension()\n",
       "        )\n",
       "        (modulation): Conv2d(1000, 256, kernel_size=(1, 1), stride=(1, 1))\n",
       "      )\n",
       "      (pre_mod_output): Identity()\n",
       "      (total_mod): Identity()\n",
       "      (post_mod_output): Identity()\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = models.alexnet_lrm3()\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bd9a7e7e-4392-4fad-a79b-cdebd66dd7cf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 1000])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand(5,3,224,224)\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    out = model(x)\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ab1e6d3-f6bd-411e-826a-7fd261f91668",
   "metadata": {},
   "source": [
    "## lrm2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ca186c50-3e92-4e93-bbaf-0582734e95ec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LRMNet(\n",
       "  (backbone): AlexNet(\n",
       "    (features): Sequential(\n",
       "      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n",
       "      (1): ReLU()\n",
       "      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
       "      (4): ReLU()\n",
       "      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (7): ReLU(inplace=True)\n",
       "      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (9): ReLU()\n",
       "      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (11): ReLU()\n",
       "      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    )\n",
       "    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n",
       "    (classifier): Sequential(\n",
       "      (0): Dropout(p=0.5, inplace=False)\n",
       "      (1): Linear(in_features=9216, out_features=4096, bias=True)\n",
       "      (2): ReLU(inplace=True)\n",
       "      (3): Dropout(p=0.5, inplace=False)\n",
       "      (4): Linear(in_features=4096, out_features=4096, bias=True)\n",
       "      (5): ReLU(inplace=True)\n",
       "      (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (lrm): Sequential(\n",
       "    (features0_modulation): LongRangeModulation(\n",
       "      (from_features_9_to_features_0): ModBlock(\n",
       "        (rescale): NormSquashResize(\n",
       "          (norm): AdaptiveFullstackNorm((256, 55, 55), eps=1e-05, elementwise_affine=True)\n",
       "          (squash): FeedbackScale(mode='tanh')\n",
       "          (interp): AdaptiveUpsample(upsample_mode='UpsampleBilinear')\n",
       "        )\n",
       "        (modulation): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "      )\n",
       "      (pre_mod_output): Identity()\n",
       "      (total_mod): Identity()\n",
       "      (post_mod_output): Identity()\n",
       "    )\n",
       "    (features3_modulation): LongRangeModulation(\n",
       "      (from_features_12_to_features_3): ModBlock(\n",
       "        (rescale): NormSquashResize(\n",
       "          (norm): AdaptiveFullstackNorm((256, 27, 27), eps=1e-05, elementwise_affine=True)\n",
       "          (squash): FeedbackScale(mode='tanh')\n",
       "          (interp): AdaptiveUpsample(upsample_mode='UpsampleBilinear')\n",
       "        )\n",
       "        (modulation): Conv2d(256, 192, kernel_size=(1, 1), stride=(1, 1))\n",
       "      )\n",
       "      (pre_mod_output): Identity()\n",
       "      (total_mod): Identity()\n",
       "      (post_mod_output): Identity()\n",
       "    )\n",
       "    (features8_modulation): LongRangeModulation(\n",
       "      (from_classifier_6_to_features_8): ModBlock(\n",
       "        (rescale): NormSquashResize(\n",
       "          (norm): ChannelNorm((1000,), eps=1e-05, elementwise_affine=True)\n",
       "          (squash): FeedbackScale(mode='tanh')\n",
       "          (interp): AddSpatialDimension()\n",
       "        )\n",
       "        (modulation): Conv2d(1000, 256, kernel_size=(1, 1), stride=(1, 1))\n",
       "      )\n",
       "      (pre_mod_output): Identity()\n",
       "      (total_mod): Identity()\n",
       "      (post_mod_output): Identity()\n",
       "    )\n",
       "    (features10_modulation): LongRangeModulation(\n",
       "      (from_classifier_6_to_features_10): ModBlock(\n",
       "        (rescale): NormSquashResize(\n",
       "          (norm): ChannelNorm((1000,), eps=1e-05, elementwise_affine=True)\n",
       "          (squash): FeedbackScale(mode='tanh')\n",
       "          (interp): AddSpatialDimension()\n",
       "        )\n",
       "        (modulation): Conv2d(1000, 256, kernel_size=(1, 1), stride=(1, 1))\n",
       "      )\n",
       "      (pre_mod_output): Identity()\n",
       "      (total_mod): Identity()\n",
       "      (post_mod_output): Identity()\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = models.alexnet_lrm2()\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b1b123ea-bd09-4b67-bedb-5b3094c4c48f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 1000])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand(5,3,224,224)\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    out = model(x)\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a29e0ad-d17a-47df-9a78-5fbe953fc114",
   "metadata": {},
   "source": [
    "## lrm1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "68d0f374-30f1-4459-bcbb-acc414fc3ab8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LRMNet(\n",
       "  (backbone): AlexNet(\n",
       "    (features): Sequential(\n",
       "      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n",
       "      (1): ReLU()\n",
       "      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
       "      (4): ReLU(inplace=True)\n",
       "      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (7): ReLU(inplace=True)\n",
       "      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (9): ReLU()\n",
       "      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (11): ReLU(inplace=True)\n",
       "      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    )\n",
       "    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n",
       "    (classifier): Sequential(\n",
       "      (0): Dropout(p=0.5, inplace=False)\n",
       "      (1): Linear(in_features=9216, out_features=4096, bias=True)\n",
       "      (2): ReLU(inplace=True)\n",
       "      (3): Dropout(p=0.5, inplace=False)\n",
       "      (4): Linear(in_features=4096, out_features=4096, bias=True)\n",
       "      (5): ReLU(inplace=True)\n",
       "      (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (lrm): Sequential(\n",
       "    (features8_modulation): LongRangeModulation(\n",
       "      (from_classifier_6_to_features_8): ModBlock(\n",
       "        (rescale): NormSquashResize(\n",
       "          (norm): ChannelNorm((1000,), eps=1e-05, elementwise_affine=True)\n",
       "          (squash): FeedbackScale(mode='tanh')\n",
       "          (interp): AddSpatialDimension()\n",
       "        )\n",
       "        (modulation): Conv2d(1000, 256, kernel_size=(1, 1), stride=(1, 1))\n",
       "      )\n",
       "      (pre_mod_output): Identity()\n",
       "      (total_mod): Identity()\n",
       "      (post_mod_output): Identity()\n",
       "    )\n",
       "    (features0_modulation): LongRangeModulation(\n",
       "      (from_features_9_to_features_0): ModBlock(\n",
       "        (rescale): NormSquashResize(\n",
       "          (norm): AdaptiveFullstackNorm((256, 55, 55), eps=1e-05, elementwise_affine=True)\n",
       "          (squash): FeedbackScale(mode='tanh')\n",
       "          (interp): AdaptiveUpsample(upsample_mode='UpsampleBilinear')\n",
       "        )\n",
       "        (modulation): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "      )\n",
       "      (pre_mod_output): Identity()\n",
       "      (total_mod): Identity()\n",
       "      (post_mod_output): Identity()\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = models.alexnet_lrm1()\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b4511ad9-d0e6-4e58-8405-b58326923080",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 1000])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand(5,3,224,224)\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    out = model(x)\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ad903662-a1da-4887-90db-518c2121a153",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LRMNet(\n",
       "  (backbone): AlexNet(\n",
       "    (features): Sequential(\n",
       "      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n",
       "      (1): ReLU(inplace=True)\n",
       "      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
       "      (4): ReLU(inplace=True)\n",
       "      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (7): ReLU(inplace=True)\n",
       "      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (9): ReLU(inplace=True)\n",
       "      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (11): ReLU(inplace=True)\n",
       "      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    )\n",
       "    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n",
       "    (classifier): Sequential(\n",
       "      (0): Dropout(p=0.5, inplace=False)\n",
       "      (1): Linear(in_features=9216, out_features=4096, bias=True)\n",
       "      (2): ReLU()\n",
       "      (3): Dropout(p=0.5, inplace=False)\n",
       "      (4): Linear(in_features=4096, out_features=4096, bias=True)\n",
       "      (5): ReLU(inplace=True)\n",
       "      (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (lrm): Sequential(\n",
       "    (classifier1_modulation): LongRangeModulation(\n",
       "      (from_classifier_6_to_classifier_1): ModBlock(\n",
       "        (rescale): NormSquashResize(\n",
       "          (norm): ChannelNorm((1000,), eps=1e-05, elementwise_affine=True)\n",
       "          (squash): FeedbackScale(mode='tanh')\n",
       "          (interp): AddSpatialDimension()\n",
       "        )\n",
       "        (modulation): Conv2d(1000, 4096, kernel_size=(1, 1), stride=(1, 1))\n",
       "      )\n",
       "      (pre_mod_output): Identity()\n",
       "      (total_mod): Identity()\n",
       "      (post_mod_output): Identity()\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# load backbone\n",
    "backbone = alexnet()\n",
    "\n",
    "# Feedback is specified as a tuple with a 'target_layer' and list of [source_layers]\n",
    "# - features.8 recieves feedback from classifier.6\n",
    "# - features.0 receives feedback from features.9\n",
    "\n",
    "mod_connections = [ \n",
    "    ('classifier.1', ['classifier.6']),\n",
    "]\n",
    "\n",
    "# create the model with default number of forward passes (time_steps) and expected img_size\n",
    "# - default number of forward passes can be overridden in forward pass\n",
    "# - actual input img_size can be any size and LRMNet will adapt feedback size automatically\n",
    "model = LRMNet(backbone, mod_connections, time_steps=2, img_size=224)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fb186e72-816e-4f10-89c8-d782bb7d260d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 4096, 5, 1000])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand(5,3,224,224)\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    out = model(x)\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d1d8c6b-0b71-4001-8cf9-b19fbde29628",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ffcv",
   "language": "python",
   "name": "ffcv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
