{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import os\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"2\"\n",
    "import torchvision\n",
    "from voc import *\n",
    "from coco import *\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.models import resnet152, resnet101, resnet18, resnet34, resnet50\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from config import seed_everything\n",
    "seed_everything(0)\n",
    "\n",
    "from models import *\n",
    "from backbones.config import config\n",
    "import pathlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'resnet50': 'resnetv2_50x3_bitm_in21k', 'swin': 'swin_base_patch4_window12_384_in22k', 'swin_large': 'swin_large_patch4_window12_384_in22k', 'convnext': 'convnext_large_in22k', 'resnet101': 'resnetv2_101x1_bitm_in21k'}\n",
      "resnet50 : resnetv2_50x3_bitm_in21k\n",
      "swin : swin_base_patch4_window12_384_in22k\n",
      "swin_large : swin_large_patch4_window12_384_in22k\n",
      "convnext : convnext_large_in22k\n",
      "resnet101 : resnetv2_101x1_bitm_in21k\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/functional.py:568: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  /opt/conda/conda-bld/pytorch_1646756402876/work/aten/src/ATen/native/TensorShape.cpp:2228.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "InterSwin\n",
      "2304 384\n",
      "torch.Size([1, 256, 112, 112])\n",
      "1 2048 1 1\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "print(config)\n",
    "for k, v in config.items():\n",
    "  print(\"{} : {}\".format(k, v))\n",
    "  pathlib.Path('../figures/{}'.format(k)).mkdir(parents=True, exist_ok=True) \n",
    "\n",
    "m_li = [\n",
    "  # base_resnet50(model_path=config['resnet50'], num_classes=80, image_size=224, pretrained=True),\\\n",
    "  #  base_vit(config['vit'], 80, image_size=224, pretrained=True),\\\n",
    "      base_swin(config['swin_large'], 80, image_size=384, pretrained=True),\\\n",
    "        #  base_convnext(config['convnext'], 80, image_size=224, pretrained=True), \\\n",
    "          #  base_mlpmixer(config['mlpmixer'], num_classes=80, image_size=224, pretrained=True), \\\n",
    "             base_resnet101(model_path=config['resnet101'], num_classes=80, image_size=448, pretrained=True)\n",
    "             ]\n",
    "# m_li2 = [BaseResnet(m_li[0], 80)]\n",
    "p_li = [\n",
    "  # '/home/seongha/LT-ML/checkpoint/coco/coco_LT(0)_label_cnt_in21k-4-4-0_resnet50_base_best.pth.tar', \\\n",
    "  # '/home/seongha/LT-ML/checkpoint/coco/coco_LT(0)_label_cnt_in21k-4-4-0_vit_base_best.pth.tar', \\\n",
    "    # '/home/seongha/LT-ML/checkpoint/coco/coco_LT(0)_label_cnt_in22k-4-4-0_swin_base_best.pth.tar',\\\n",
    "  # '/home/seongha/LT-ML/checkpoint/coco/coco_LT(0)_label_cnt_in22k-4-4-0_convnext_base_best.pth.tar',\\\n",
    "      # '/home/seongha/LT-ML/checkpoint/coco/coco_LT(0)_label_cnt_in21k-4-4-0_mlpmixer_base_best.pth.tar' ,\\\n",
    "      # 'checkpoint/coco/coco_bs50-384-scheduler_swin_large_base_best.pth.tar',\\\n",
    "      'checkpoint/coco/coco_l_alpha-@0_swin_large_base_best.pth.tar',\\\n",
    "        '../checkpoint/coco/coco_bs64-448-scheduler_resnet101_base_best.pth.tar  '\n",
    "  ]\n",
    "\n",
    "def get_model(index):\n",
    "  path = p_li[index]\n",
    "  model = m_li[index]\n",
    "  di = torch.load(path)\n",
    "  print(di['best_score'])\n",
    "  print(di.keys())\n",
    "  model.load_state_dict(di['state_dict'])\n",
    "\n",
    "  for n, p in model.named_parameters():\n",
    "    if p.requires_grad==False:\n",
    "      p.requires_grad=True\n",
    "  return model\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pool with window of size=3, stride=2\n",
    "m = nn.AvgPool1d(3, stride=2)\n",
    "m(torch.tensor([[[1.,2,3,4,5,6,7]]])) #1,1,7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([20, 768, 1, 1])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# pool of square window of size=3, stride=2\n",
    "m = nn.AvgPool2d(3, stride=2)\n",
    "# pool of non-square window\n",
    "m = nn.AvgPool2d((63, 63), stride=(1, 1))\n",
    "input = torch.randn(20, 768, 63, 63)\n",
    "output = m(input)\n",
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "decoder_layer = nn.TransformerDecoderLayer(d_model=768, nhead=1)\n",
    "memory = torch.rand(32, 196, 768)\n",
    "m = nn.AvgPool2d((63, 63), stride=(1, 1))\n",
    "memory = m(memory)\n",
    "memory = torch.squeeze(memory, -1)\n",
    "memory = torch.squeeze(memory, -1)\n",
    "tgt = torch.rand(32, 768)\n",
    "out = decoder_layer(tgt, memory) #32,768\n",
    "\n",
    "m = nn.AvgPool1d(196, stride=1)\n",
    "out = m(out)\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Calculate symmetric padding for a convolution\n",
    "def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:\n",
    "    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2\n",
    "    return padding\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([20, 2048, 1, 1])"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# With square kernels and equal stride\n",
    "m = nn.Conv2d(16, 33, 3, stride=2)\n",
    "# # non-square kernels and unequal stride and with padding\n",
    "# m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))\n",
    "# # non-square kernels and unequal stride and with padding and dilation\n",
    "m = nn.Conv2d(256, 2048, (31, 31), stride=(64, 64), dilation=(2, 2), padding=1)\n",
    "input = torch.randn(20, 256, 122, 122)\n",
    "output = m(input)\n",
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'nn' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb Cell 9'\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.143/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000016vscode-remote?line=0'>1</a>\u001b[0m \u001b[39m# pool of square window of size=3, stride=2\u001b[39;00m\n\u001b[0;32m----> <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.143/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000016vscode-remote?line=1'>2</a>\u001b[0m m \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mAvgPool2d(\u001b[39m3\u001b[39m, stride\u001b[39m=\u001b[39m\u001b[39m2\u001b[39m)\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.143/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000016vscode-remote?line=2'>3</a>\u001b[0m \u001b[39m# pool of non-square window\u001b[39;00m\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.143/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000016vscode-remote?line=3'>4</a>\u001b[0m m \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mAvgPool2d((\u001b[39m3\u001b[39m, \u001b[39m2\u001b[39m), stride\u001b[39m=\u001b[39m(\u001b[39m2\u001b[39m, \u001b[39m1\u001b[39m))\n",
      "\u001b[0;31mNameError\u001b[0m: name 'nn' is not defined"
     ]
    }
   ],
   "source": [
    "# pool of square window of size=3, stride=2\n",
    "m = nn.AvgPool2d(3, stride=2)\n",
    "# pool of non-square window\n",
    "m = nn.AvgPool2d((3, 2), stride=(2, 1))\n",
    "input = torch.randn(20, 16, 50, 32)\n",
    "output = m(input)\n",
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'mAP': tensor(74.8108), 'OF1': 0.7193339297047961, 'CF1': 0.6759560773634149}\n",
      "dict_keys(['epoch', 'arch', 'state_dict', 'best_score'])\n",
      "torch.Size([1, 1024])\n",
      "torch.Size([1, 196, 512])\n",
      "torch.Size([1, 1024])\n",
      "torch.Size([1, 80])\n"
     ]
    }
   ],
   "source": [
    "#swin\n",
    "model = get_model(2)\n",
    "inp = torch.rand(3, 224, 224)\n",
    "inp = torch.unsqueeze(inp, 0)\n",
    "out = model.features(inp)\n",
    "print(out.shape)\n",
    "\n",
    "inp = torch.rand(3, 224, 224)\n",
    "inp = torch.unsqueeze(inp, 0)\n",
    "seq = nn.Sequential(*[model.features[0], model.features[1], model.features[2][0], model.features[2][1]]) #2-2-17-2 blocks\n",
    "out = seq(inp)\n",
    "print(out.shape) #[1, 784, 256], [1, 196, 512]\n",
    "b, n, h = out.shape\n",
    "out = out.reshape((1, -1))\n",
    "\n",
    "m = nn.AvgPool1d(n*h - 1024 + 1, stride=1)\n",
    "out = m(out).squeeze(-1)\n",
    "print(out.shape)\n",
    "\n",
    "logit = model.fc(out)\n",
    "print(logit.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#vit\n",
    "model = get_model(1)\n",
    "inp = torch.rand(3, 224, 224)\n",
    "inp = torch.unsqueeze(inp, 0)\n",
    "out = model.features(inp)\n",
    "print(out.shape)\n",
    "\n",
    "inp = torch.rand(3, 224, 224)\n",
    "inp = torch.unsqueeze(inp, 0)\n",
    "seq = nn.Sequential(*[model.features[0], model.features[1], model.features[2][0], model.features[2][1], model.features[2][2]])\n",
    "out = seq(inp)\n",
    "print(out.shape) #([1, 196, 768])\n",
    "b, n, h = out.shape\n",
    "\n",
    "out = torch.swapaxes(out, 1,2)\n",
    "m = nn.AvgPool1d(196, stride=1)\n",
    "out = m(out).squeeze(-1)\n",
    "print(out.shape)\n",
    "\n",
    "logit = model.fc(out)\n",
    "print(logit.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'ResNetV2' object has no attribute 'conv1'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[1;32m/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb Cell 12'\u001b[0m in \u001b[0;36m<cell line: 6>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000011vscode-remote?line=3'>4</a>\u001b[0m inp \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrand(\u001b[39m3\u001b[39m, \u001b[39m488\u001b[39m, \u001b[39m488\u001b[39m)\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000011vscode-remote?line=4'>5</a>\u001b[0m inp \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39munsqueeze(inp, \u001b[39m0\u001b[39m)\n\u001b[0;32m----> <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000011vscode-remote?line=5'>6</a>\u001b[0m seq \u001b[39m=\u001b[39m nn\u001b[39m.\u001b[39mSequential(\u001b[39m*\u001b[39m[model\u001b[39m.\u001b[39;49mconv1, model\u001b[39m.\u001b[39mbn1, model\u001b[39m.\u001b[39mact1, model\u001b[39m.\u001b[39mmaxpool, model\u001b[39m.\u001b[39mlayer1, model\u001b[39m.\u001b[39mlayer2, model\u001b[39m.\u001b[39mlayer3, model\u001b[39m.\u001b[39mlayer4])\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000011vscode-remote?line=6'>7</a>\u001b[0m out \u001b[39m=\u001b[39m seq(inp)\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000011vscode-remote?line=7'>8</a>\u001b[0m \u001b[39mprint\u001b[39m(out\u001b[39m.\u001b[39mshape) \u001b[39m#torch.Size([1, 2048, 16, 16])\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/nn/modules/module.py:1185\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/nn/modules/module.py?line=1182'>1183</a>\u001b[0m     \u001b[39mif\u001b[39;00m name \u001b[39min\u001b[39;00m modules:\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/nn/modules/module.py?line=1183'>1184</a>\u001b[0m         \u001b[39mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/nn/modules/module.py?line=1184'>1185</a>\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAttributeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m object has no attribute \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m   <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/nn/modules/module.py?line=1185'>1186</a>\u001b[0m     \u001b[39mtype\u001b[39m(\u001b[39mself\u001b[39m)\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, name))\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'ResNetV2' object has no attribute 'conv1'"
     ]
    }
   ],
   "source": [
    "#resnet50\n",
    "import timm\n",
    "model = timm.create_model(\"resnet50\", pretrained=True, num_classes=80)\n",
    "inp = torch.rand(3, 488, 488)\n",
    "inp = torch.unsqueeze(inp, 0)\n",
    "seq = nn.Sequential(*[model.conv1, model.bn1, model.act1, model.maxpool, model.layer1, model.layer2, model.layer3, model.layer4])\n",
    "out = seq(inp)\n",
    "print(out.shape) #torch.Size([1, 2048, 16, 16])\n",
    "\n",
    "inp = torch.rand(3, 488, 488)\n",
    "inp = torch.unsqueeze(inp, 0)\n",
    "seq = nn.Sequential(*[model.conv1, model.bn1, model.act1, model.maxpool, model.layer1])\n",
    "out = seq(inp)\n",
    "print(out.shape) #[1, 256, 122, 122])\n",
    "b, n, h, w = out.shape\n",
    "out = out.reshape((1, -1))\n",
    "print(out.shape) #3810304\n",
    "\n",
    "m = nn.AvgPool1d(n*h*w - 3810304 + 1, stride=1)\n",
    "# m = nn.Conv2d(n, )\n",
    "out = m(out)\n",
    "out = out.reshape((-1, 6144, 1, 1))\n",
    "print(out.shape)\n",
    "\n",
    "logit = model.head.fc(out)\n",
    "print(logit.shape)\n",
    "logit = logit.flatten()\n",
    "print(logit.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: '../checkpoint/coco/coco_bs64-448-scheduler_resnet101_base_best.pth.tar  '",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb Cell 13'\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000012vscode-remote?line=0'>1</a>\u001b[0m \u001b[39m#resnet101\u001b[39;00m\n\u001b[0;32m----> <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000012vscode-remote?line=1'>2</a>\u001b[0m model \u001b[39m=\u001b[39m get_model(\u001b[39m-\u001b[39;49m\u001b[39m1\u001b[39;49m)\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000012vscode-remote?line=2'>3</a>\u001b[0m inp \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrand(\u001b[39m3\u001b[39m, \u001b[39m224\u001b[39m, \u001b[39m224\u001b[39m)\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000012vscode-remote?line=3'>4</a>\u001b[0m inp \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39munsqueeze(inp, \u001b[39m0\u001b[39m)\n",
      "\u001b[1;32m/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb Cell 2'\u001b[0m in \u001b[0;36mget_model\u001b[0;34m(index)\u001b[0m\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000001vscode-remote?line=25'>26</a>\u001b[0m path \u001b[39m=\u001b[39m p_li[index]\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000001vscode-remote?line=26'>27</a>\u001b[0m model \u001b[39m=\u001b[39m m_li[index]\n\u001b[0;32m---> <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000001vscode-remote?line=27'>28</a>\u001b[0m di \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mload(path)\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000001vscode-remote?line=28'>29</a>\u001b[0m \u001b[39mprint\u001b[39m(di[\u001b[39m'\u001b[39m\u001b[39mbest_score\u001b[39m\u001b[39m'\u001b[39m])\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000001vscode-remote?line=29'>30</a>\u001b[0m \u001b[39mprint\u001b[39m(di\u001b[39m.\u001b[39mkeys())\n",
      "File \u001b[0;32m~/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py:699\u001b[0m, in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py?line=695'>696</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m'\u001b[39m\u001b[39mencoding\u001b[39m\u001b[39m'\u001b[39m \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m pickle_load_args\u001b[39m.\u001b[39mkeys():\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py?line=696'>697</a>\u001b[0m     pickle_load_args[\u001b[39m'\u001b[39m\u001b[39mencoding\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m \u001b[39m'\u001b[39m\u001b[39mutf-8\u001b[39m\u001b[39m'\u001b[39m\n\u001b[0;32m--> <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py?line=698'>699</a>\u001b[0m \u001b[39mwith\u001b[39;00m _open_file_like(f, \u001b[39m'\u001b[39;49m\u001b[39mrb\u001b[39;49m\u001b[39m'\u001b[39;49m) \u001b[39mas\u001b[39;00m opened_file:\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py?line=699'>700</a>\u001b[0m     \u001b[39mif\u001b[39;00m _is_zipfile(opened_file):\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py?line=700'>701</a>\u001b[0m         \u001b[39m# The zipfile reader is going to advance the current file position.\u001b[39;00m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py?line=701'>702</a>\u001b[0m         \u001b[39m# If we want to actually tail call to torch.jit.load, we need to\u001b[39;00m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py?line=702'>703</a>\u001b[0m         \u001b[39m# reset back to the original position.\u001b[39;00m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py?line=703'>704</a>\u001b[0m         orig_position \u001b[39m=\u001b[39m opened_file\u001b[39m.\u001b[39mtell()\n",
      "File \u001b[0;32m~/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py:231\u001b[0m, in \u001b[0;36m_open_file_like\u001b[0;34m(name_or_buffer, mode)\u001b[0m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py?line=228'>229</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_open_file_like\u001b[39m(name_or_buffer, mode):\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py?line=229'>230</a>\u001b[0m     \u001b[39mif\u001b[39;00m _is_path(name_or_buffer):\n\u001b[0;32m--> <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py?line=230'>231</a>\u001b[0m         \u001b[39mreturn\u001b[39;00m _open_file(name_or_buffer, mode)\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py?line=231'>232</a>\u001b[0m     \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py?line=232'>233</a>\u001b[0m         \u001b[39mif\u001b[39;00m \u001b[39m'\u001b[39m\u001b[39mw\u001b[39m\u001b[39m'\u001b[39m \u001b[39min\u001b[39;00m mode:\n",
      "File \u001b[0;32m~/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py:212\u001b[0m, in \u001b[0;36m_open_file.__init__\u001b[0;34m(self, name, mode)\u001b[0m\n\u001b[1;32m    <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py?line=210'>211</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, name, mode):\n\u001b[0;32m--> <a href='file:///home/seongha/anaconda3/envs/MGSSL/lib/python3.9/site-packages/torch/serialization.py?line=211'>212</a>\u001b[0m     \u001b[39msuper\u001b[39m(_open_file, \u001b[39mself\u001b[39m)\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mopen\u001b[39;49m(name, mode))\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../checkpoint/coco/coco_bs64-448-scheduler_resnet101_base_best.pth.tar  '"
     ]
    }
   ],
   "source": [
    "#resnet101\n",
    "model = get_model(-1)\n",
    "inp = torch.rand(3, 224, 224)\n",
    "inp = torch.unsqueeze(inp, 0)\n",
    "out = model.features(inp)\n",
    "out = model.head.global_pool(out)\n",
    "print(out.shape) #torch.Size([1, 6144, 1, 1])\n",
    "\n",
    "inp = torch.rand(3, 224, 224)\n",
    "inp = torch.unsqueeze(inp, 0)\n",
    "seq = nn.Sequential(*[model.features[0], model.features[1][0], model.features[1][1],model.features[1][2]])\n",
    "out = seq(inp)\n",
    "print(out.shape) #[1, 3072, 14, 14])\n",
    "b, n, h, w = out.shape\n",
    "out = out.reshape((1, -1))\n",
    "print(out.shape)\n",
    "\n",
    "m = nn.AvgPool1d(n*h*w - 6144 + 1, stride=1)\n",
    "# m = nn.Conv2d(n, )\n",
    "out = m(out)\n",
    "out = out.reshape((-1, 6144, 1, 1))\n",
    "print(out.shape)\n",
    "\n",
    "logit = model.head.fc(out)\n",
    "print(logit.shape)\n",
    "logit = logit.flatten()\n",
    "print(logit.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "mlpmixer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = get_model(4)\n",
    "# import torch.nn.functional as F\n",
    "\n",
    "inp = torch.rand(3, 224, 224)\n",
    "inp = torch.unsqueeze(inp, 0)\n",
    "seq = nn.Sequential(*[model.features[0], model.features[1][0]])\n",
    "out = seq(inp)\n",
    "print(out.shape)\n",
    "out = torch.swapaxes(out, 1,2)\n",
    "m = nn.AvgPool1d(196, stride=1)\n",
    "out = m(out).squeeze(-1)\n",
    "# out = torch.mean(out.view(out.size(0), out.size(1), -1), dim=2)\n",
    "torch.mean(torch.stack([out, out]), 1)\n",
    "print(out.shape)\n",
    "\n",
    "logit = model.fc(out)\n",
    "print(logit.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'get_model' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb Cell 16'\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000018vscode-remote?line=0'>1</a>\u001b[0m model \u001b[39m=\u001b[39m get_model(\u001b[39m4\u001b[39m)\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000018vscode-remote?line=1'>2</a>\u001b[0m \u001b[39m# import torch.nn.functional as F\u001b[39;00m\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2B143.248.157.158/home/seongha/LT-ML/notebooks/intermediate_layer.ipynb#ch0000018vscode-remote?line=3'>4</a>\u001b[0m inp \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrand(\u001b[39m3\u001b[39m, \u001b[39m448\u001b[39m, \u001b[39m448\u001b[39m)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'get_model' is not defined"
     ]
    }
   ],
   "source": [
    "model = get_model(4)\n",
    "# import torch.nn.functional as F\n",
    "\n",
    "inp = torch.rand(3, 448, 448)\n",
    "inp = torch.unsqueeze(inp, 0)\n",
    "seq = nn.Sequential(*[model.features[0], model.features[1][0]])\n",
    "out = seq(inp)\n",
    "print(out.shape)\n",
    "out = torch.swapaxes(out, 1,2)\n",
    "m = nn.AvgPool1d(196, stride=1)\n",
    "out = m(out).squeeze(-1)\n",
    "# out = torch.mean(out.view(out.size(0), out.size(1), -1), dim=2)\n",
    "torch.mean(torch.stack([out, out]), 1)\n",
    "print(out.shape)\n",
    "\n",
    "logit = model.fc(out)\n",
    "print(logit.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Count number of training parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/seongha/anaconda3/envs/LTML/lib/python3.9/site-packages/torch/functional.py:568: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  /opt/conda/conda-bld/pytorch_1646756402876/work/aten/src/ATen/native/TensorShape.cpp:2228.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "InterSwin\n",
      "2304 384\n",
      "<class 'backbones.swin.InterSwin'>\n",
      "330\n",
      "3\n",
      "scale\n",
      "fc.weight\n",
      "fc.bias\n",
      "2\n",
      "<class 'backbones.resnet.InterResnetV2'>\n",
      "307\n",
      "3\n",
      "scale\n",
      "fc.weight\n",
      "fc.bias\n",
      "2\n"
     ]
    }
   ],
   "source": [
    "def count_training_params(model, path):\n",
    "  print(model.__class__)\n",
    "  di = torch.load(path)\n",
    "  model.load_state_dict(di['state_dict'])\n",
    "  ##full finetune\n",
    "  print(len(list(model.parameters())))\n",
    "\n",
    "  ## ours\n",
    "  print(np.sum(np.array(list(p.requires_grad for p in model.parameters()))))\n",
    "  for n, p in model.named_parameters():\n",
    "    if p.requires_grad:\n",
    "      print(n)\n",
    "\n",
    "  ## linear probe\n",
    "  cnt=0\n",
    "  for n, p in model.fc.named_parameters():\n",
    "    if p.requires_grad:\n",
    "      cnt+=1\n",
    "  print(cnt)\n",
    "\n",
    "model = base_swin(config['swin_large'], 80, 384, True, cond=True, where=0)\n",
    "path='../checkpoint/coco/coco_bs50-at0-lr_scheduler_swin_large_base_best.pth.tar'\n",
    "count_training_params(model,path)\n",
    "path='../checkpoint/coco/coco_@0-bs64-448-scheduler_resnet101_base_best.pth.tar'\n",
    "model = base_resnet101(config['resnet101'], 80, 448, True, cond=True, where=0)\n",
    "count_training_params(model,path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "765b26ec2fc4c2066cf7ce8a0dde5a8255de29dd3973b3be957926608459ba30"
  },
  "kernelspec": {
   "display_name": "Python 3.10.4 ('MGSSL')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
