{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
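    "# Getting inference working on a timm-pretrained model\n",
    "\n",
    "Loads the `SamAdamDay/resnet18_cifar10` checkpoint from the Hugging Face Hub through timm, builds timm's evaluation-time preprocessing for it, and measures top-1 / top-5 accuracy on the CIFAR-10 test split."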
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# timm model identifier; the `hf_hub:` prefix points timm at the Hugging Face Hub.\n",
    "MODEL_NAME = \"hf_hub:SamAdamDay/resnet18_cifar10\"\n",
    "\n",
    "# Where CIFAR-10 is stored / downloaded. Set the CIFAR10_ROOT environment variable to\n",
    "# override this machine-specific default instead of editing the notebook.\n",
    "ROOT_DIR = os.environ.get(\n",
    "    \"CIFAR10_ROOT\",\n",
    "    \"~/Code/Projects/PVG Experiments/data/image_classification/cifar10/raw/\",\n",
    ")\n",
    "\n",
    "BATCH_SIZE = 256\n",
    "\n",
    "# When True, run on the CPU even if a GPU is available.\n",
    "FORCE_CPU = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from contextlib import suppress\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from torchvision.datasets import CIFAR10\n",
    "\n",
    "import timm\n",
    "from timm import create_model\n",
    "from timm.models import ResNet\n",
    "from timm.data import (\n",
    "    resolve_data_config,\n",
    "    create_transform,\n",
    "    create_dataset,\n",
    "    create_loader,\n",
    ")\n",
    "from timm.utils import accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Use the GPU when one is present, unless FORCE_CPU pins execution to the CPU.\n",
    "device = torch.device(\n",
    "    \"cuda\" if (torch.cuda.is_available() and not FORCE_CPU) else \"cpu\"\n",
    ")\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ResNet(\n",
       "  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (act1): ReLU(inplace=True)\n",
       "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  (layer1): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (drop_block): Identity()\n",
       "      (act1): ReLU(inplace=True)\n",
       "      (aa): Identity()\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act2): ReLU(inplace=True)\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (drop_block): Identity()\n",
       "      (act1): ReLU(inplace=True)\n",
       "      (aa): Identity()\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act2): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer2): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (drop_block): Identity()\n",
       "      (act1): ReLU(inplace=True)\n",
       "      (aa): Identity()\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act2): ReLU(inplace=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (drop_block): Identity()\n",
       "      (act1): ReLU(inplace=True)\n",
       "      (aa): Identity()\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act2): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer3): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (drop_block): Identity()\n",
       "      (act1): ReLU(inplace=True)\n",
       "      (aa): Identity()\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act2): ReLU(inplace=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (drop_block): Identity()\n",
       "      (act1): ReLU(inplace=True)\n",
       "      (aa): Identity()\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act2): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer4): Sequential(\n",
       "    (0): BasicBlock(\n",
       "      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (drop_block): Identity()\n",
       "      (act1): ReLU(inplace=True)\n",
       "      (aa): Identity()\n",
       "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act2): ReLU(inplace=True)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): BasicBlock(\n",
       "      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (drop_block): Identity()\n",
       "      (act1): ReLU(inplace=True)\n",
       "      (aa): Identity()\n",
       "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act2): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
       "  (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# nn.Module.to() moves the parameters and returns the module itself, so the call can be chained.\n",
    "model = create_model(MODEL_NAME, pretrained=True).to(device)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1000"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): the recorded output is 1000 (matching the `fc` layer above) while CIFAR-10 has\n",
    "# only 10 classes -- confirm the checkpoint's logits 0-9 correspond to the CIFAR-10 labels.\n",
    "(num_classes := model.num_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_size': [3, 288, 288],\n",
       " 'interpolation': 'bicubic',\n",
       " 'mean': [0.485, 0.456, 0.406],\n",
       " 'std': [0.229, 0.224, 0.225],\n",
       " 'crop_pct': 1.0,\n",
       " 'crop_mode': 'center'}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preprocessing settings (input size, interpolation, normalization, cropping) attached to the\n",
    "# pretrained weights; use_test_size prefers the checkpoint's test-time input size if it has one.\n",
    "data_config = resolve_data_config(model=model, use_test_size=True, verbose=True)\n",
    "data_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Compose(\n",
       "    Resize(size=288, interpolation=bicubic, max_size=None, antialias=True)\n",
       "    CenterCrop(size=[288, 288])\n",
       "    ToTensor()\n",
       "    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluation-time preprocessing built from the pretrained data config. With is_training=False,\n",
    "# timm only uses the resize / crop / normalization settings, so training-augmentation options\n",
    "# (flips, color jitter, random erasing, ...) are intentionally not passed.\n",
    "transform = create_transform(**data_config, is_training=False)\n",
    "transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Dataset CIFAR10\n",
       "    Number of datapoints: 10000\n",
       "    Root location: /home/sam/Code/Projects/PVG Experiments/data/image_classification/cifar10/raw/\n",
       "    Split: Test\n",
       "    StandardTransform\n",
       "Transform: Compose(\n",
       "               Resize(size=288, interpolation=bicubic, max_size=None, antialias=True)\n",
       "               CenterCrop(size=[288, 288])\n",
       "               ToTensor()\n",
       "               Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))\n",
       "           )"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# CIFAR-10 test split (train=False) from torchvision, with the timm eval transform attached.\n",
    "# (An earlier attempt used timm's create_dataset(\"torch/cifar10\", split=\"validation\").)\n",
    "dataset = CIFAR10(\n",
    "    root=ROOT_DIR,\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=transform,\n",
    ")\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.utils.data.dataloader.DataLoader at 0x7f7aa012fa10>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Plain PyTorch DataLoader: preprocessing already happens in dataset.transform, so timm's\n",
    "# create_loader is not needed. shuffle=False and drop_last=False mean every test image is\n",
    "# visited exactly once, in a fixed order.\n",
    "loader = DataLoader(\n",
    "    dataset,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=False,\n",
    "    num_workers=1,\n",
    "    drop_last=False,\n",
    ")\n",
    "loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Compose(\n",
       "    Resize(size=288, interpolation=bicubic, max_size=None, antialias=True)\n",
       "    CenterCrop(size=[288, 288])\n",
       "    ToTensor()\n",
       "    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity checks: the dataset applies the eval transform built above, and a sample comes out\n",
    "# at the input size the pretrained config expects.\n",
    "assert dataset.transform is transform\n",
    "assert tuple(dataset[0][0].shape) == tuple(data_config[\"input_size\"])\n",
    "dataset.transform"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on the whole test split. timm's `accuracy` returns per-batch percentages, so turn\n",
    "# them back into correct-prediction counts and accumulate them; this weights every batch by\n",
    "# its size (the final batch is smaller than BATCH_SIZE).\n",
    "model.eval()\n",
    "num_seen = 0\n",
    "num_correct_top1 = 0\n",
    "num_correct_top5 = 0\n",
    "with torch.no_grad():\n",
    "    for batch_idx, (images, targets) in enumerate(loader):\n",
    "        images = images.to(device)\n",
    "        targets = targets.to(device)\n",
    "        logits = model(images)\n",
    "        acc1, acc5 = accuracy(logits, targets, topk=(1, 5))\n",
    "\n",
    "        batch_size = targets.size(0)\n",
    "        num_seen += batch_size\n",
    "        # round() removes float error from the percentage -> count conversion.\n",
    "        num_correct_top1 += round(acc1.item() * batch_size / 100)\n",
    "        num_correct_top5 += round(acc5.item() * batch_size / 100)\n",
    "        print(\n",
    "            f\"Batch [{batch_idx + 1}/{len(loader)}]: \"\n",
    "            f\"Top-1 accuracy: {acc1.item():.2f} \"\n",
    "            f\"(running {100 * num_correct_top1 / num_seen:.2f}), \"\n",
    "            f\"Top-5 accuracy: {acc5.item():.2f} \"\n",
    "            f\"(running {100 * num_correct_top5 / num_seen:.2f})\"\n",
    "        )\n",
    "\n",
    "assert num_seen == len(dataset), f\"evaluated {num_seen} of {len(dataset)} images\"\n",
    "{\n",
    "    \"num_samples\": num_seen,\n",
    "    \"top1_accuracy\": 100 * num_correct_top1 / num_seen,\n",
    "    \"top5_accuracy\": 100 * num_correct_top5 / num_seen,\n",
    "}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pvg-experiments",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
