{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "602d9044",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T06:15:23.979869Z",
     "iopub.status.busy": "2025-01-25T06:15:23.979630Z",
     "iopub.status.idle": "2025-01-25T06:18:11.504623Z",
     "shell.execute_reply": "2025-01-25T06:18:11.503777Z"
    },
    "papermill": {
     "duration": 167.530233,
     "end_time": "2025-01-25T06:18:11.506768",
     "exception": false,
     "start_time": "2025-01-25T06:15:23.976535",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found existing installation: torch 2.1.2\r\n",
      "Uninstalling torch-2.1.2:\r\n",
      "  Successfully uninstalled torch-2.1.2\r\n",
      "Found existing installation: torchvision 0.16.2\r\n",
      "Uninstalling torchvision-0.16.2:\r\n",
      "  Successfully uninstalled torchvision-0.16.2\r\n",
      "Found existing installation: torchaudio 2.1.2\r\n",
      "Uninstalling torchaudio-2.1.2:\r\n",
      "  Successfully uninstalled torchaudio-2.1.2\r\n",
      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\r\n",
      "fastai 2.7.14 requires torch<2.3,>=1.10, but you have torch 2.5.1+cu121 which is incompatible.\u001b[0m\u001b[31m\r\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip uninstall -y torch torchvision torchaudio\n",
    "!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f40b4008",
   "metadata": {
    "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
    "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
    "execution": {
     "iopub.execute_input": "2025-01-25T06:18:11.512944Z",
     "iopub.status.busy": "2025-01-25T06:18:11.512668Z",
     "iopub.status.idle": "2025-01-25T06:18:14.447199Z",
     "shell.execute_reply": "2025-01-25T06:18:14.446553Z"
    },
    "papermill": {
     "duration": 2.939806,
     "end_time": "2025-01-25T06:18:14.449175",
     "exception": false,
     "start_time": "2025-01-25T06:18:11.509369",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import PIL\n",
    "\n",
    "from tqdm import tqdm\n",
    "from torchvision.io import read_image, ImageReadMode\n",
    "from torchvision.datasets import ImageFolder\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader, Dataset, ConcatDataset\n",
    "from torchvision.transforms.v2 import GaussianNoise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f46aed4a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T06:18:14.454948Z",
     "iopub.status.busy": "2025-01-25T06:18:14.454622Z",
     "iopub.status.idle": "2025-01-25T06:18:14.511646Z",
     "shell.execute_reply": "2025-01-25T06:18:14.510992Z"
    },
    "papermill": {
     "duration": 0.061624,
     "end_time": "2025-01-25T06:18:14.513250",
     "exception": false,
     "start_time": "2025-01-25T06:18:14.451626",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "image_dir = \"/kaggle/input/imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "03371a54",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T06:18:14.518748Z",
     "iopub.status.busy": "2025-01-25T06:18:14.518512Z",
     "iopub.status.idle": "2025-01-25T06:18:15.522340Z",
     "shell.execute_reply": "2025-01-25T06:18:15.521187Z"
    },
    "papermill": {
     "duration": 1.008508,
     "end_time": "2025-01-25T06:18:15.524057",
     "exception": false,
     "start_time": "2025-01-25T06:18:14.515549",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n",
      "100%|██████████| 44.7M/44.7M [00:00<00:00, 66.1MB/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ImageClassification(\n",
      "    crop_size=[224]\n",
      "    resize_size=[256]\n",
      "    mean=[0.485, 0.456, 0.406]\n",
      "    std=[0.229, 0.224, 0.225]\n",
      "    interpolation=InterpolationMode.BILINEAR\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "from torchvision.models import *\n",
    "\n",
    "model_name = 'resnet18_v1'\n",
    "weights = ResNet18_Weights.IMAGENET1K_V1\n",
    "model = resnet18(weights=weights)\n",
    "model_transform = weights.transforms()\n",
    "print(model_transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fd2722c2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T06:18:15.531407Z",
     "iopub.status.busy": "2025-01-25T06:18:15.531123Z",
     "iopub.status.idle": "2025-01-25T06:18:15.536817Z",
     "shell.execute_reply": "2025-01-25T06:18:15.536028Z"
    },
    "papermill": {
     "duration": 0.011232,
     "end_time": "2025-01-25T06:18:15.538371",
     "exception": false,
     "start_time": "2025-01-25T06:18:15.527139",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class ImagenetTrainClassDataset(Dataset):\n",
    "    def __init__(self, path:str, class_id:int, transform):\n",
    "        assert path.split('/')[-1] == 'train'\n",
    "        super().__init__()\n",
    "        class_names = sorted(os.listdir(path))\n",
    "        self.class_name = class_names[class_id]\n",
    "        self.class_path = path + '/' + self.class_name\n",
    "        \n",
    "        self.img_names = sorted(os.listdir(self.class_path))\n",
    "        self.transform = transform\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        img_path = self.class_path + '/' + self.img_names[idx]\n",
    "        image = PIL.Image.open(img_path).convert(\"RGB\")\n",
    "        return self.transform(image)\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.img_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d480a1c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-25T06:18:15.544783Z",
     "iopub.status.busy": "2025-01-25T06:18:15.544327Z",
     "iopub.status.idle": "2025-01-25T10:44:51.599722Z",
     "shell.execute_reply": "2025-01-25T10:44:51.598577Z"
    },
    "papermill": {
     "duration": 15996.060847,
     "end_time": "2025-01-25T10:44:51.601888",
     "exception": false,
     "start_time": "2025-01-25T06:18:15.541041",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compose(\n",
      "    ToTensor()\n",
      "    GaussianNoise(mean=0.0, sigma=0.15, clip=True)\n",
      "    ImageClassification(\n",
      "    crop_size=[224]\n",
      "    resize_size=[256]\n",
      "    mean=[0.485, 0.456, 0.406]\n",
      "    std=[0.229, 0.224, 0.225]\n",
      "    interpolation=InterpolationMode.BILINEAR\n",
      ")\n",
      ")\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:19<00:00, 50.40it/s]\n",
      "12812it [4:26:11,  1.25s/it]\n"
     ]
    }
   ],
   "source": [
    "train_path = image_dir + '/train'\n",
    "\n",
    "num_classes = 1000\n",
    "chunk_size = 1000 # number of classes in one chunk\n",
    "\n",
    "model = model.to(device)\n",
    "model.eval()\n",
    "\n",
    "for sigma in [0.3, 0.5, 0.7, 0.9]:\n",
    "    os.makedirs(f\"std{sigma}-noised-model-outputs\", exist_ok=True)\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        GaussianNoise(sigma=sigma),\n",
    "        model_transform\n",
    "    ])\n",
    "    print(transform)\n",
    "    \n",
    "    # For reproducibility\n",
    "    torch.manual_seed(12)\n",
    "    np.random.seed(12)\n",
    "    random.seed(12)\n",
    "    \n",
    "    for chunk_id in range(1):\n",
    "        subsets = []\n",
    "        name_list = []\n",
    "        for i in tqdm(range(chunk_id * chunk_size, (chunk_id+1) * chunk_size)):\n",
    "            class_subset = ImagenetTrainClassDataset(train_path, class_id=i, transform=transform)\n",
    "            subsets.append(class_subset)\n",
    "            name_list += class_subset.img_names\n",
    "    \n",
    "        name_list = [name.split('.')[0] for name in name_list] # remove JPEG extension\n",
    "        subset = ConcatDataset(subsets)\n",
    "        train_dataloader = DataLoader(subset, batch_size=100, shuffle=False, num_workers=2)\n",
    "        train_probs = torch.empty((len(subset), 1000), dtype=torch.float16)\n",
    "    \n",
    "        with torch.no_grad():\n",
    "            for i, images in tqdm(enumerate(train_dataloader)):\n",
    "                images = images.to(device)\n",
    "                logits = model(images)\n",
    "                probs = F.softmax(logits, dim=1)\n",
    "                train_probs[i*100: i*100 + probs.size(0)] = probs.detach().cpu().half()\n",
    "    \n",
    "        output = {\n",
    "            'probs': train_probs,\n",
    "            'img_names': name_list\n",
    "        }\n",
    "    \n",
    "        torch.save(output, f'{f\"std{sigma}-noised-model-outputs\"}/{model_name}_train_{chunk_id}.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a5e3ef8",
   "metadata": {
    "papermill": {
     "duration": 0.452114,
     "end_time": "2025-01-25T10:44:52.509378",
     "exception": false,
     "start_time": "2025-01-25T10:44:52.057264",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kaggle": {
   "accelerator": "gpu",
   "dataSources": [
    {
     "databundleVersionId": 4225553,
     "sourceId": 6799,
     "sourceType": "competition"
    }
   ],
   "dockerImageVersionId": 30665,
   "isGpuEnabled": true,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 16174.949811,
   "end_time": "2025-01-25T10:44:56.419844",
   "environment_variables": {},
   "exception": null,
   "input_path": "__notebook__.ipynb",
   "output_path": "__notebook__.ipynb",
   "parameters": {},
   "start_time": "2025-01-25T06:15:21.470033",
   "version": "2.5.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
