{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:10.429791Z",
     "iopub.status.busy": "2023-09-20T21:58:10.429141Z",
     "iopub.status.idle": "2023-09-20T21:58:10.577578Z",
     "shell.execute_reply": "2023-09-20T21:58:10.577066Z"
    },
    "id": "pXiB88NXEp1s"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import shutil\n",
    "import time\n",
    "from copy import deepcopy\n",
    "\n",
    "device = 'cuda'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:10.579630Z",
     "iopub.status.busy": "2023-09-20T21:58:10.579434Z",
     "iopub.status.idle": "2023-09-20T21:58:10.819286Z",
     "shell.execute_reply": "2023-09-20T21:58:10.818767Z"
    },
    "id": "yFPSOYa0EvYG"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import cv2\n",
    "import random\n",
    "from natsort import natsorted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:10.821336Z",
     "iopub.status.busy": "2023-09-20T21:58:10.821129Z",
     "iopub.status.idle": "2023-09-20T21:58:11.727523Z",
     "shell.execute_reply": "2023-09-20T21:58:11.727028Z"
    },
    "id": "CFYQ5nV5Eva8"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import random_split\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data import Subset\n",
    "from torchvision.models.vgg import vgg19"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:11.729606Z",
     "iopub.status.busy": "2023-09-20T21:58:11.729382Z",
     "iopub.status.idle": "2023-09-20T21:58:11.732133Z",
     "shell.execute_reply": "2023-09-20T21:58:11.731701Z"
    },
    "id": "C_wXE_85Evdj"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "from torchvision.io import read_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:11.733658Z",
     "iopub.status.busy": "2023-09-20T21:58:11.733496Z",
     "iopub.status.idle": "2023-09-20T21:58:11.735868Z",
     "shell.execute_reply": "2023-09-20T21:58:11.735541Z"
    },
    "id": "68hUuTgOEvgb"
   },
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:11.843311Z",
     "iopub.status.busy": "2023-09-20T21:58:11.843178Z",
     "iopub.status.idle": "2023-09-20T21:58:11.845387Z",
     "shell.execute_reply": "2023-09-20T21:58:11.844963Z"
    },
    "id": "A09AWEm026fa"
   },
   "outputs": [],
   "source": [
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:11.901183Z",
     "iopub.status.busy": "2023-09-20T21:58:11.901049Z",
     "iopub.status.idle": "2023-09-20T21:58:11.904160Z",
     "shell.execute_reply": "2023-09-20T21:58:11.903720Z"
    },
    "id": "WKPltmdJUQxz"
   },
   "outputs": [],
   "source": [
    "class AddGaussianNoise(object):\n",
    "    def __init__(self, mean=0., std=1.):\n",
    "        self.std = std\n",
    "        self.mean = mean\n",
    "        \n",
    "    def __call__(self, tensor):\n",
    "        if random.random() > 0.5 :\n",
    "          return tensor + torch.randn(tensor.size()) * self.std + self.mean\n",
    "        else :\n",
    "          return tensor\n",
    "    \n",
    "    def __repr__(self):\n",
    "        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:11.905621Z",
     "iopub.status.busy": "2023-09-20T21:58:11.905487Z",
     "iopub.status.idle": "2023-09-20T21:58:11.907434Z",
     "shell.execute_reply": "2023-09-20T21:58:11.907105Z"
    },
    "id": "pxuVJ56w5P7S"
   },
   "outputs": [],
   "source": [
    "# Resize-only pipeline; the next cell reassigns modify_transforms to the augmented one.\n",
    "modify_transforms = transforms.Compose([\n",
    "    transforms.Resize((224, 224)),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:11.908931Z",
     "iopub.status.busy": "2023-09-20T21:58:11.908797Z",
     "iopub.status.idle": "2023-09-20T21:58:11.912301Z",
     "shell.execute_reply": "2023-09-20T21:58:11.911970Z"
    },
    "id": "8SuiiuTrR_U0"
   },
   "outputs": [],
   "source": [
    "# id swap\n",
    "# Training pipeline: random rotation, resize to 224x224, random horizontal flip,\n",
    "# per-channel normalization, then Gaussian noise on roughly half of the samples.\n",
    "# NOTE(review): the mean/std magnitudes suggest statistics of raw 0-255 pixel values\n",
    "# (CustomDataSet passes unscaled cv2 pixels) -- TODO record how they were computed.\n",
    "NORM_MEAN = [93.4020, 95.5818, 113.5519]\n",
    "NORM_STD = [55.9949, 53.7595, 64.2225]\n",
    "\n",
    "modify_transforms = transforms.Compose([\n",
    "    transforms.RandomRotation(30),\n",
    "    transforms.Resize((224, 224)),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.Normalize(NORM_MEAN, NORM_STD),\n",
    "    AddGaussianNoise(0.0, 0.1),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:11.913797Z",
     "iopub.status.busy": "2023-09-20T21:58:11.913661Z",
     "iopub.status.idle": "2023-09-20T21:58:11.915742Z",
     "shell.execute_reply": "2023-09-20T21:58:11.915408Z"
    },
    "id": "Azn4UKLKzWro",
    "outputId": "ac96b632-ed59-4dc9-cf2a-dc53eed3744c"
   },
   "outputs": [],
   "source": [
    "# !unzip AMSL_FaceMorphImageDataSet.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:11.917273Z",
     "iopub.status.busy": "2023-09-20T21:58:11.917117Z",
     "iopub.status.idle": "2023-09-20T21:58:11.919538Z",
     "shell.execute_reply": "2023-09-20T21:58:11.919112Z"
    },
    "id": "D8OYBERrzWuP"
   },
   "outputs": [],
   "source": [
    "# class CustomDataSet(Dataset):\n",
    "#     def __init__(self, main_dir, transform):\n",
    "#         self.main_dir = main_dir\n",
    "#         self.transform = transform\n",
    "#         #all_imgs = os.listdir(main_dir)\n",
    "#         #self.total_imgs = natsort.natsorted(all_imgs)\n",
    "#         self.total_imgs = os.listdir(main_dir)\n",
    "#         if \"_morph_\" in self.main_dir:\n",
    "#           self.labels = torch.ones(len(self.total_imgs))\n",
    "#         elif \"DeepFakes\" in self.main_dir:\n",
    "#           self.labels = torch.ones(len(self.total_imgs))\n",
    "#         else:\n",
    "#           self.labels = torch.zeros(len(self.total_imgs))\n",
    "\n",
    "#     def __len__(self):\n",
    "#         return len(self.total_imgs)\n",
    "\n",
    "#     def __getitem__(self, idx):\n",
    "#         img_loc = os.path.join(self.main_dir, self.total_imgs[idx])\n",
    "#         img = cv2.imread(img_loc)\n",
    "#         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "#         image = cv2.resize(img, (224,224), interpolation = cv2.INTER_AREA)\n",
    "#         #image = Image.open(img_loc).convert(\"RGB\")\n",
    "#         tensor_image = torch.Tensor(image)\n",
    "#         tensor_image = torch.permute(tensor_image, (2, 0, 1))\n",
    "#         tensor_image = self.transform(tensor_image)\n",
    "#         label_img = self.labels[idx]\n",
    "#         return tensor_image , label_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:11.921101Z",
     "iopub.status.busy": "2023-09-20T21:58:11.920945Z",
     "iopub.status.idle": "2023-09-20T21:58:11.922968Z",
     "shell.execute_reply": "2023-09-20T21:58:11.922541Z"
    }
   },
   "outputs": [],
   "source": [
    "# from facenet_pytorch import MTCNN\n",
    "# mtcnn = MTCNN(margin=40, select_largest=False, post_process=False, device='cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:11.976589Z",
     "iopub.status.busy": "2023-09-20T21:58:11.976395Z",
     "iopub.status.idle": "2023-09-20T21:58:11.980405Z",
     "shell.execute_reply": "2023-09-20T21:58:11.980074Z"
    }
   },
   "outputs": [],
   "source": [
    "class CustomDataSet(Dataset):\n",
    "    def __init__(self, main_dir, transform):\n",
    "        self.main_dir = main_dir\n",
    "        self.transform = transform\n",
    "        #self.model = model\n",
    "        all_imgs = os.listdir(main_dir)\n",
    "        self.total_imgs = natsorted(all_imgs)\n",
    "        # self.total_imgs = os.listdir(main_dir)\n",
    "        if \"original\" in self.main_dir or \"ffhq_real\" in self.main_dir or \"real\" in self.main_dir:\n",
    "          self.labels = torch.zeros(len(self.total_imgs))\n",
    "        else:\n",
    "          self.labels = torch.ones(len(self.total_imgs))\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.total_imgs)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])\n",
    "        img = cv2.imread(img_loc)\n",
    "        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "        # img = self.model(img)\n",
    "        # image = cv2.resize(img, (224,224), interpolation = cv2.INTER_AREA)\n",
    "        # image = Image.open(img_loc).convert(\"RGB\")\n",
    "        tensor_image = torch.Tensor(img)\n",
    "        tensor_image = torch.permute(tensor_image, (2, 0, 1))\n",
    "        tensor_image = self.transform(tensor_image)\n",
    "        #tensor_image = self.transform(img)\n",
    "        label_img = self.labels[idx]\n",
    "        return tensor_image , label_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:11.988726Z",
     "iopub.status.busy": "2023-09-20T21:58:11.988570Z",
     "iopub.status.idle": "2023-09-20T21:58:12.600820Z",
     "shell.execute_reply": "2023-09-20T21:58:12.600326Z"
    }
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import ConcatDataset\n",
    "\n",
    "# Image folders for the real ('original_sequences') and manipulated ('manipulated_sequences')\n",
    "# videos. CustomDataSet labels the two paths containing 'original' as real (0), the rest fake (1).\n",
    "_DOWNLOADS = \"/home/ashit/Downloads\"\n",
    "original = CustomDataSet(_DOWNLOADS + \"/original/original_sequences/youtube/c23/images\", modify_transforms)\n",
    "faceswap = CustomDataSet(_DOWNLOADS + \"/FaceSwap/manipulated_sequences/FaceSwap/c23/images\", modify_transforms)\n",
    "face2face = CustomDataSet(_DOWNLOADS + \"/Face2Face/manipulated_sequences/Face2Face/c23/images\", modify_transforms)\n",
    "deepfakes = CustomDataSet(_DOWNLOADS + \"/DeepFakeDetection/manipulated_sequences/DeepFakeDetection/c23/images\", modify_transforms)\n",
    "deepfake_originals = CustomDataSet(_DOWNLOADS + \"/DeepFakeDetection_original/original_sequences/actors/c23/images\", modify_transforms)\n",
    "# data = ConcatDataset([original,faceswap,face2face,deepfake_originals,deepfakes])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:12.602918Z",
     "iopub.status.busy": "2023-09-20T21:58:12.602666Z",
     "iopub.status.idle": "2023-09-20T21:58:13.231035Z",
     "shell.execute_reply": "2023-09-20T21:58:13.230503Z"
    }
   },
   "outputs": [],
   "source": [
    "# The ffhq_real splits are labeled real (0) by CustomDataSet (path contains 'real');\n",
    "# the faceapp / stargan / stylegan_ffhq / pggan_v1 splits are labeled fake (1).\n",
    "def _folder_splits(root):\n",
    "    \"\"\"Build the (train, test, validation) CustomDataSets stored under ``root``.\"\"\"\n",
    "    return tuple(\n",
    "        CustomDataSet(root + '/' + split, modify_transforms)\n",
    "        for split in ('train', 'test', 'validation')\n",
    "    )\n",
    "\n",
    "\n",
    "ffhq_train, ffhq_test, ffhq_val = _folder_splits(\"/home/ashit/Downloads/ffhq_real\")\n",
    "faceapp_train, faceapp_test, faceapp_val = _folder_splits(\"/home/ashit/Downloads/faceapp\")\n",
    "stargan_train, stargan_test, stargan_val = _folder_splits(\"/home/ashit/Downloads/stargan\")\n",
    "stylegan_ffhq_train, stylegan_ffhq_test, stylegan_ffhq_val = _folder_splits(\"/home/ashit/Downloads/stylegan_ffhq\")\n",
    "pggan_v1_train, pggan_v1_test, pggan_v1_val = _folder_splits(\"/home/ashit/Downloads/pggan_v1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:13.244064Z",
     "iopub.status.busy": "2023-09-20T21:58:13.243881Z",
     "iopub.status.idle": "2023-09-20T21:58:15.019419Z",
     "shell.execute_reply": "2023-09-20T21:58:15.018888Z"
    }
   },
   "outputs": [],
   "source": [
    "# id swap\n",
    "\n",
    "import copy\n",
    "import random\n",
    "import torch.utils.data as data_utils\n",
    "\n",
    "random.seed(42)\n",
    "\n",
    "# Deterministic evaluation pipeline: keep only the resize and normalization steps of\n",
    "# `modify_transforms` and drop its random rotation / flip / noise, so validation and\n",
    "# test images are not randomly augmented and their metrics are reproducible.\n",
    "eval_transforms = transforms.Compose(\n",
    "    [t for t in modify_transforms.transforms\n",
    "     if isinstance(t, (transforms.Resize, transforms.Normalize))]\n",
    ")\n",
    "\n",
    "\n",
    "def split_dataset(dataset, n_train=2500, n_val=1000, n_test=2500):\n",
    "    \"\"\"Randomly split ``dataset`` into disjoint train / validation / test Subsets.\n",
    "\n",
    "    Indices are drawn with the global ``random`` module in a fixed order (train first,\n",
    "    then validation from the remainder, then test from what is left), so the split is\n",
    "    reproducible under ``random.seed``. The index range comes from ``len(dataset)``\n",
    "    rather than a hard-coded count, so no index can fall outside the dataset.\n",
    "\n",
    "    The train Subset keeps the dataset's own transform; the validation and test\n",
    "    Subsets read from a shallow copy whose transform is ``eval_transforms``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: (from ``random.sample``) if the dataset has fewer than\n",
    "            ``n_train + n_val + n_test`` items.\n",
    "    \"\"\"\n",
    "    n_items = len(dataset)\n",
    "    train_idx = random.sample(range(n_items), n_train)\n",
    "    # Set membership keeps each filtering pass linear instead of O(n_items * k).\n",
    "    used = set(train_idx)\n",
    "    remaining = [i for i in range(n_items) if i not in used]\n",
    "    val_idx = random.sample(remaining, n_val)\n",
    "    used = set(val_idx)\n",
    "    remaining = [i for i in remaining if i not in used]\n",
    "    test_idx = random.sample(remaining, n_test)\n",
    "\n",
    "    eval_dataset = copy.copy(dataset)  # shares the file list and labels; only the transform differs\n",
    "    eval_dataset.transform = eval_transforms\n",
    "    return (data_utils.Subset(dataset, train_idx),\n",
    "            data_utils.Subset(eval_dataset, val_idx),\n",
    "            data_utils.Subset(eval_dataset, test_idx))\n",
    "\n",
    "\n",
    "original_train, original_val, original_test = split_dataset(original)\n",
    "deepfake_originals_train, deepfake_originals_val, deepfake_originals_test = split_dataset(deepfake_originals)\n",
    "deepfakes_train, deepfakes_val, deepfakes_test = split_dataset(deepfakes)\n",
    "faceswap_train, faceswap_val, faceswap_test = split_dataset(faceswap)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.043912Z",
     "iopub.status.busy": "2023-09-20T21:58:15.043743Z",
     "iopub.status.idle": "2023-09-20T21:58:15.046442Z",
     "shell.execute_reply": "2023-09-20T21:58:15.046119Z"
    }
   },
   "outputs": [],
   "source": [
    "# id swap\n",
    "# Each split pools the same four sources in the same order: two real-image sets\n",
    "# (original, deepfake_originals) followed by two manipulated sets (deepfakes, faceswap).\n",
    "train_data = ConcatDataset([\n",
    "    original_train,\n",
    "    deepfake_originals_train,\n",
    "    deepfakes_train,\n",
    "    faceswap_train,\n",
    "])\n",
    "# train_data = ConcatDataset([ffhq_train,londondb,morph_train,originals,fakes_train])\n",
    "val_data = ConcatDataset([\n",
    "    original_val,\n",
    "    deepfake_originals_val,\n",
    "    deepfakes_val,\n",
    "    faceswap_val,\n",
    "])\n",
    "test_data = ConcatDataset([\n",
    "    original_test,\n",
    "    deepfake_originals_test,\n",
    "    deepfakes_test,\n",
    "    faceswap_test,\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.080446Z",
     "iopub.status.busy": "2023-09-20T21:58:15.080287Z",
     "iopub.status.idle": "2023-09-20T21:58:15.083275Z",
     "shell.execute_reply": "2023-09-20T21:58:15.082844Z"
    },
    "id": "tbxV2iLA2Dpt"
   },
   "outputs": [],
   "source": [
    "# Shared loader settings; only the training loader shuffles its samples.\n",
    "loader_kwargs = dict(batch_size=32, num_workers=2, pin_memory=True)\n",
    "train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, **loader_kwargs)\n",
    "val_loader = torch.utils.data.DataLoader(val_data, shuffle=False, **loader_kwargs)\n",
    "test_loader = torch.utils.data.DataLoader(test_data, shuffle=False, **loader_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.085192Z",
     "iopub.status.busy": "2023-09-20T21:58:15.085056Z",
     "iopub.status.idle": "2023-09-20T21:58:15.092779Z",
     "shell.execute_reply": "2023-09-20T21:58:15.092338Z"
    },
    "id": "qCX4ik5_zWz8"
   },
   "outputs": [],
   "source": [
    "# # model class(paper)4\n",
    "# class VGG19_CBAM(torch.nn.Module):\n",
    "\n",
    "#   # init function\n",
    "#   def __init__(self, model, num_classes=2):\n",
    "#     super().__init__()\n",
    "\n",
    "#     # pool layer\n",
    "#     self.pool = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "\n",
    "#     # parameter\n",
    "#     self.alpha = torch.nn.Parameter(torch.tensor([0.0]))\n",
    "#     self.beta = torch.nn.Parameter(torch.tensor([0.0]))\n",
    "\n",
    "#     # self.cab_1 = nn.Sequential(\n",
    "#     #     nn.Conv2d(64, 64, kernel_size=1, stride=1),\n",
    "#     #     nn.PReLU(),\n",
    "#     #     nn.BatchNorm2d(64)\n",
    "#     # )\n",
    "    \n",
    "#     # self.cab_3 = nn.Sequential(\n",
    "#     #     nn.Conv2d(128, 128, kernel_size=1, stride=1),\n",
    "#     #     nn.PReLU(),\n",
    "#     #     nn.BatchNorm2d(128)\n",
    "#     # )\n",
    "    \n",
    "#     # self.cab_5 = nn.Sequential(\n",
    "#     #     nn.Conv2d(256, 256, kernel_size=1, stride=1),\n",
    "#     #     nn.PReLU(),\n",
    "#     #     nn.BatchNorm2d(256)\n",
    "#     # )\n",
    "    \n",
    "#     self.cab_9 = nn.Sequential(\n",
    "#         nn.Conv2d(512, 512, kernel_size=1, stride=1),\n",
    "#         nn.PReLU(),\n",
    "#         nn.BatchNorm2d(512)\n",
    "#     )\n",
    "    \n",
    "\n",
    "#     # spatial attention\n",
    "#     self.spatial_attention = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(2, 1, kernel_size=3, padding=1, stride=1),\n",
    "#         torch.nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(1)\n",
    "#     )\n",
    "\n",
    "#     # channel attention\n",
    "#     self.max_pool_1 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=224, stride=224))\n",
    "#     self.max_pool_2 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=112, stride=112))\n",
    "#     self.max_pool_3 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=56, stride=56))\n",
    "#     self.max_pool_4 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=28, stride=28))\n",
    "#     self.max_pool_5 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=14, stride=14))\n",
    "#     self.avg_pool_1 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=224, stride=224))\n",
    "#     self.avg_pool_2 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=112, stride=112))\n",
    "#     self.avg_pool_3 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=56, stride=56))\n",
    "#     self.avg_pool_4 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=28, stride=28))\n",
    "#     self.avg_pool_5 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=14, stride=14))\n",
    "\n",
    "#     # features without considering vgg19 pooling layer\n",
    "#     self.features_1 = torch.nn.Sequential(*list(model.features.children())[:3])\n",
    "#     self.features_2 = torch.nn.Sequential(*list(model.features.children())[3:6])\n",
    "#     self.features_3 = torch.nn.Sequential(*list(model.features.children())[7:10])\n",
    "#     self.features_4 = torch.nn.Sequential(*list(model.features.children())[10:13])\n",
    "#     self.features_5 = torch.nn.Sequential(*list(model.features.children())[14:17])\n",
    "#     self.features_6 = torch.nn.Sequential(*list(model.features.children())[17:20])\n",
    "#     self.features_7 = torch.nn.Sequential(*list(model.features.children())[20:23])\n",
    "#     self.features_8 = torch.nn.Sequential(*list(model.features.children())[23:26])\n",
    "#     self.features_9 = torch.nn.Sequential(*list(model.features.children())[27:30])\n",
    "#     self.features_10 = torch.nn.Sequential(*list(model.features.children())[30:33])\n",
    "#     self.features_11 = torch.nn.Sequential(*list(model.features.children())[33:36])\n",
    "#     self.features_12 = torch.nn.Sequential(*list(model.features.children())[36:39])\n",
    "#     self.features_13 = torch.nn.Sequential(*list(model.features.children())[40:43])\n",
    "#     self.features_14 = torch.nn.Sequential(*list(model.features.children())[43:46])\n",
    "#     self.features_15 = torch.nn.Sequential(*list(model.features.children())[46:49])\n",
    "#     self.features_16 = torch.nn.Sequential(*list(model.features.children())[49:52])\n",
    "\n",
    "#     self.avgpool = nn.AdaptiveAvgPool2d(7)\n",
    "\n",
    "#     # classifier\n",
    "#     self.classifier = torch.nn.Sequential(\n",
    "#         torch.nn.Linear(25088, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 2)\n",
    "#     )\n",
    "\n",
    "\n",
    "#   # forward\n",
    "#   def forward(self, x):\n",
    "#     x = self.features_1(x)\n",
    "#     # b, c, h, w = x.size()\n",
    "#     # scale = self.max_pool_1(x) + self.avg_pool_1(x)\n",
    "#     # scale = self.cab_1(scale).squeeze(3)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(scale,torch.transpose(scale,1,2)))\n",
    "#     # temp_x = self.alpha * torch.bmm(scale,torch.reshape(x, (b,c,h*w)))\n",
    "#     # x = x + torch.reshape(temp_x,(b,c,h,w))\n",
    "#     # scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     # scale = self.spatial_attention(scale)\n",
    "#     # scale = torch.bmm(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))).cpu()\n",
    "#     # scale = torch.nn.functional.softmax(scale)\n",
    "#     # temp_x = torch.bmm(torch.reshape(x.cpu(), (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     # temp_x = self.beta.cpu() * temp_x\n",
    "#     # x = (x + torch.reshape(temp_x, (b,c,h,w))).cuda()\n",
    "#     # self.beta = self.beta.cuda()\n",
    "\n",
    "#     x = self.features_2(x)\n",
    "#     # b, c, h, w = x.size()\n",
    "#     # scale = self.max_pool_1(x) + self.avg_pool_1(x)\n",
    "#     # scale = self.cab_1(scale).squeeze(3)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(scale,torch.transpose(scale,1,2)))\n",
    "#     # temp_x = self.alpha * torch.bmm(scale,torch.reshape(x, (b,c,h*w)))\n",
    "#     # x = x + torch.reshape(temp_x,(b,c,h,w))\n",
    "#     # scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     # scale = self.spatial_attention(scale)\n",
    "#     # scale = torch.nn.functional.softmax(torch.bmm(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     # temp_x = self.beta * torch.bmm(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     # x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_3(x)\n",
    "#     # b, c, h, w = x.size()\n",
    "#     # scale = self.max_pool_2(x) + self.avg_pool_2(x)\n",
    "#     # scale = self.cab_3(scale).squeeze(3)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(scale,torch.transpose(scale,1,2)))\n",
    "#     # temp_x = self.alpha * torch.matmul(scale,torch.reshape(x, (b,c,h*w)))\n",
    "#     # x = x + torch.reshape(temp_x,(b,c,h,w))\n",
    "#     # scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     # scale = self.spatial_attention(scale)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     # temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     # x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_4(x)\n",
    "#     # b, c, h, w = x.size()\n",
    "#     # scale = self.max_pool_2(x) + self.avg_pool_2(x)\n",
    "#     # scale = self.cab_3(scale).squeeze(3)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(scale,torch.transpose(scale,1,2)))\n",
    "#     # temp_x = self.alpha * torch.matmul(scale,torch.reshape(x, (b,c,h*w)))\n",
    "#     # x = x + torch.reshape(temp_x,(b,c,h,w))\n",
    "#     # scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     # scale = self.spatial_attention(scale)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     # temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     # x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_5(x)\n",
    "#     # b, c, h, w = x.size()\n",
    "#     # scale = self.max_pool_3(x) + self.avg_pool_3(x)\n",
    "#     # scale = self.cab_5(scale).squeeze(3)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(scale,torch.transpose(scale,1,2)))\n",
    "#     # temp_x = self.alpha * torch.matmul(scale,torch.reshape(x, (b,c,h*w)))\n",
    "#     # x = x + torch.reshape(temp_x,(b,c,h,w))\n",
    "#     # scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     # scale = self.spatial_attention(scale)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     # temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     # x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_6(x)\n",
    "#     # b, c, h, w = x.size()\n",
    "#     # scale = self.max_pool_3(x) + self.avg_pool_3(x)\n",
    "#     # scale = self.cab_5(scale).squeeze(3)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(scale,torch.transpose(scale,1,2)))\n",
    "#     # temp_x = self.alpha * torch.matmul(scale,torch.reshape(x, (b,c,h*w)))\n",
    "#     # x = x + torch.reshape(temp_x,(b,c,h,w))\n",
    "#     # scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     # scale = self.spatial_attention(scale)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     # temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     # x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_7(x)\n",
    "#     # b, c, h, w = x.size()\n",
    "#     # scale = self.max_pool_3(x) + self.avg_pool_3(x)\n",
    "#     # scale = self.cab_5(scale).squeeze(3)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(scale,torch.transpose(scale,1,2)))\n",
    "#     # temp_x = self.alpha * torch.matmul(scale,torch.reshape(x, (b,c,h*w)))\n",
    "#     # x = x + torch.reshape(temp_x,(b,c,h,w))\n",
    "#     # scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     # scale = self.spatial_attention(scale)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     # temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     # x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_8(x)\n",
    "#     # b, c, h, w = x.size()\n",
    "#     # scale = self.max_pool_3(x) + self.avg_pool_3(x)\n",
    "#     # scale = self.cab_5(scale).squeeze(3)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(scale,torch.transpose(scale,1,2)))\n",
    "#     # temp_x = self.alpha * torch.matmul(scale,torch.reshape(x, (b,c,h*w)))\n",
    "#     # x = x + torch.reshape(temp_x,(b,c,h,w))\n",
    "#     # scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     # scale = self.spatial_attention(scale)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     # temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     # x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_9(x)\n",
    "#     # b, c, h, w = x.size()\n",
    "#     # scale = self.max_pool_4(x) + self.avg_pool_4(x)\n",
    "#     # scale = self.cab_9(scale).squeeze(3)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(scale,torch.transpose(scale,1,2)))\n",
    "#     # temp_x = self.alpha * torch.matmul(scale,torch.reshape(x, (b,c,h*w)))\n",
    "#     # x = x + torch.reshape(temp_x,(b,c,h,w))\n",
    "#     # scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     # scale = self.spatial_attention(scale)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     # temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     # x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_10(x)\n",
    "#     # b, c, h, w = x.size()\n",
    "#     # scale = self.max_pool_4(x) + self.avg_pool_4(x)\n",
    "#     # scale = self.cab_9(scale).squeeze(3)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(scale,torch.transpose(scale,1,2)))\n",
    "#     # temp_x = self.alpha * torch.matmul(scale,torch.reshape(x, (b,c,h*w)))\n",
    "#     # x = x + torch.reshape(temp_x,(b,c,h,w))\n",
    "#     # scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     # scale = self.spatial_attention(scale)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     # temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     # x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_11(x)\n",
    "#     # b, c, h, w = x.size()\n",
    "#     # scale = self.max_pool_4(x) + self.avg_pool_4(x)\n",
    "#     # scale = self.cab_9(scale).squeeze(3)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(scale,torch.transpose(scale,1,2)))\n",
    "#     # temp_x = self.alpha * torch.matmul(scale,torch.reshape(x, (b,c,h*w)))\n",
    "#     # x = x + torch.reshape(temp_x,(b,c,h,w))\n",
    "#     # scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     # scale = self.spatial_attention(scale)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     # temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     # x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_12(x)\n",
    "#     # b, c, h, w = x.size()\n",
    "#     # scale = self.max_pool_4(x) + self.avg_pool_4(x)\n",
    "#     # scale = self.cab_9(scale).squeeze(3)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(scale,torch.transpose(scale,1,2)))\n",
    "#     # temp_x = self.alpha * torch.matmul(scale,torch.reshape(x, (b,c,h*w)))\n",
    "#     # x = x + torch.reshape(temp_x,(b,c,h,w))\n",
    "#     # scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     # scale = self.spatial_attention(scale)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     # temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     # x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_13(x)\n",
    "#     # b, c, h, w = x.size()\n",
    "#     # scale = self.max_pool_5(x) + self.avg_pool_5(x)\n",
    "#     # scale = self.cab_9(scale).squeeze(3)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(scale,torch.transpose(scale,1,2)))\n",
    "#     # temp_x = self.alpha * torch.matmul(scale,torch.reshape(x, (b,c,h*w)))\n",
    "#     # x = x + torch.reshape(temp_x,(b,c,h,w))\n",
    "#     # scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     # scale = self.spatial_attention(scale)\n",
    "#     # scale = torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w)))\n",
    "#     # scale = torch.nn.functional.softmax(scale, dim = 1)\n",
    "#     # temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     # x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_14(x)\n",
    "#     # b, c, h, w = x.size()\n",
    "#     # scale = self.max_pool_5(x) + self.avg_pool_5(x)\n",
    "#     # scale = self.cab_9(scale).squeeze(3)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(scale,torch.transpose(scale,1,2)))\n",
    "#     # temp_x = self.alpha * torch.matmul(scale,torch.reshape(x, (b,c,h*w)))\n",
    "#     # x = x + torch.reshape(temp_x,(b,c,h,w))\n",
    "#     # scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     # scale = self.spatial_attention(scale)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     # temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     # x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_15(x)\n",
    "#     # b, c, h, w = x.size()\n",
    "#     # scale = self.max_pool_5(x) + self.avg_pool_5(x)\n",
    "#     # scale = self.cab_9(scale).squeeze(3)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(scale,torch.transpose(scale,1,2)))\n",
    "#     # temp_x = self.alpha * torch.matmul(scale,torch.reshape(x, (b,c,h*w)))\n",
    "#     # x = x + torch.reshape(temp_x,(b,c,h,w))\n",
    "#     # scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     # scale = self.spatial_attention(scale)\n",
    "#     # scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     # temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     # x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_16(x)\n",
    "#     b, c, h, w = x.size()\n",
    "#     scale = self.max_pool_5(x) + self.avg_pool_5(x)\n",
    "#     scale = self.cab_9(scale).squeeze(3)\n",
    "#     scale = torch.nn.functional.softmax(torch.matmul(scale,torch.transpose(scale,1,2)))\n",
    "#     temp_x = self.alpha * torch.matmul(scale,torch.reshape(x, (b,c,h*w)))\n",
    "#     x = x + torch.reshape(temp_x,(b,c,h,w))\n",
    "#     scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     scale = self.spatial_attention(scale)\n",
    "#     scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.avgpool(x)\n",
    "#     x = x.view(x.shape[0], -1)\n",
    "#     x = self.classifier(x)\n",
    "#     return x\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.094412Z",
     "iopub.status.busy": "2023-09-20T21:58:15.094275Z",
     "iopub.status.idle": "2023-09-20T21:58:15.098268Z",
     "shell.execute_reply": "2023-09-20T21:58:15.097945Z"
    },
    "id": "dDpCwEVwgnMm"
   },
   "outputs": [],
   "source": [
    "# # model class (paper) 4 spatial: spatial attention applied after each of features_13..features_16\n",
    "# class VGG19_CBAM(torch.nn.Module):\n",
    "\n",
    "#   # init function\n",
    "#   def __init__(self, model, num_classes=2):\n",
    "#     super().__init__()\n",
    "\n",
    "#     # pool layer\n",
    "#     self.pool = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "\n",
    "#     # parameter\n",
    "#     #self.alpha = torch.nn.Parameter(torch.tensor([0.0]))\n",
    "#     self.beta = torch.nn.Parameter(torch.tensor([0.0]))\n",
    "\n",
    "    \n",
    "    \n",
    "\n",
    "#     # spatial attention\n",
    "#     self.spatial_attention = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(2, 1, kernel_size=3, padding=1, stride=1),\n",
    "#         torch.nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(1)\n",
    "#     )\n",
    "\n",
    "#     # channel attention\n",
    "#     self.max_pool_1 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=224, stride=224))\n",
    "#     self.max_pool_2 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=112, stride=112))\n",
    "#     self.max_pool_3 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=56, stride=56))\n",
    "#     self.max_pool_4 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=28, stride=28))\n",
    "#     self.max_pool_5 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=14, stride=14))\n",
    "#     self.avg_pool_1 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=224, stride=224))\n",
    "#     self.avg_pool_2 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=112, stride=112))\n",
    "#     self.avg_pool_3 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=56, stride=56))\n",
    "#     self.avg_pool_4 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=28, stride=28))\n",
    "#     self.avg_pool_5 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=14, stride=14))\n",
    "\n",
    "#     # features without considering vgg19 pooling layer\n",
    "#     self.features_1 = torch.nn.Sequential(*list(model.features.children())[:3])\n",
    "#     self.features_2 = torch.nn.Sequential(*list(model.features.children())[3:6])\n",
    "#     self.features_3 = torch.nn.Sequential(*list(model.features.children())[7:10])\n",
    "#     self.features_4 = torch.nn.Sequential(*list(model.features.children())[10:13])\n",
    "#     self.features_5 = torch.nn.Sequential(*list(model.features.children())[14:17])\n",
    "#     self.features_6 = torch.nn.Sequential(*list(model.features.children())[17:20])\n",
    "#     self.features_7 = torch.nn.Sequential(*list(model.features.children())[20:23])\n",
    "#     self.features_8 = torch.nn.Sequential(*list(model.features.children())[23:26])\n",
    "#     self.features_9 = torch.nn.Sequential(*list(model.features.children())[27:30])\n",
    "#     self.features_10 = torch.nn.Sequential(*list(model.features.children())[30:33])\n",
    "#     self.features_11 = torch.nn.Sequential(*list(model.features.children())[33:36])\n",
    "#     self.features_12 = torch.nn.Sequential(*list(model.features.children())[36:39])\n",
    "#     self.features_13 = torch.nn.Sequential(*list(model.features.children())[40:43])\n",
    "#     self.features_14 = torch.nn.Sequential(*list(model.features.children())[43:46])\n",
    "#     self.features_15 = torch.nn.Sequential(*list(model.features.children())[46:49])\n",
    "#     self.features_16 = torch.nn.Sequential(*list(model.features.children())[49:52])\n",
    "\n",
    "#     self.avgpool = nn.AdaptiveAvgPool2d(7)\n",
    "\n",
    "#     # classifier\n",
    "#     self.classifier = torch.nn.Sequential(\n",
    "#         torch.nn.Linear(25088, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 2),\n",
    "#         #torch.nn.Softmax(1),\n",
    "#     )\n",
    "\n",
    "\n",
    "#   # forward\n",
    "#   def forward(self, x):\n",
    "#     x = self.features_1(x)\n",
    "    \n",
    "#     #scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "    \n",
    "\n",
    "#     x = self.features_2(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_3(x)\n",
    "    \n",
    "\n",
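    "#     # NOTE(review): features_4 is never applied in this forward; only features_3 runs between the first and second self.pool - confirm this skip is intended\n",
    "#     #b, c, h, w = x.size()\n",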
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_5(x)\n",
    "    \n",
    "    \n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_6(x)\n",
    "    \n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_7(x)\n",
    "    \n",
    "    \n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_8(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_9(x)\n",
    "    \n",
    "    \n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_10(x)\n",
    "    \n",
    "    \n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_11(x)\n",
    "    \n",
    "    \n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_12(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_13(x)\n",
    "#     b, c, h, w = x.size()\n",
    "#     scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     scale = self.spatial_attention(scale)\n",
    "#     scale = torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w)))\n",
    "#     scale = torch.nn.functional.softmax(scale, dim = 1)\n",
    "#     temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_14(x)\n",
    "#     b, c, h, w = x.size()\n",
    "#     scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     scale = self.spatial_attention(scale)\n",
    "#     scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_15(x)\n",
    "#     b, c, h, w = x.size()\n",
    "#     scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     scale = self.spatial_attention(scale)\n",
    "#     scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_16(x)\n",
    "#     b, c, h, w = x.size()\n",
    "#     scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     scale = self.spatial_attention(scale)\n",
    "#     scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.avgpool(x)\n",
    "#     x = x.view(x.shape[0], -1)\n",
    "#     x = self.classifier(x)\n",
    "#     return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.099823Z",
     "iopub.status.busy": "2023-09-20T21:58:15.099687Z",
     "iopub.status.idle": "2023-09-20T21:58:15.103483Z",
     "shell.execute_reply": "2023-09-20T21:58:15.103050Z"
    },
    "id": "zHf1gtI48sad"
   },
   "outputs": [],
   "source": [
    "# # model class (paper) 4 spatial, modified: spatial attention applied only after features_16\n",
    "# class VGG19_CBAM(torch.nn.Module):\n",
    "\n",
    "#   # init function\n",
    "#   def __init__(self, model, num_classes=2):\n",
    "#     super().__init__()\n",
    "\n",
    "#     # pool layer\n",
    "#     self.pool = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "\n",
    "#     # parameter\n",
    "#     #self.alpha = torch.nn.Parameter(torch.tensor([0.0]))\n",
    "#     self.beta = torch.nn.Parameter(torch.tensor([0.0]))\n",
    "\n",
    "    \n",
    "    \n",
    "\n",
    "#     # spatial attention\n",
    "#     self.spatial_attention = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(2, 1, kernel_size=3, padding=1, stride=1),\n",
    "#         torch.nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(1)\n",
    "#     )\n",
    "\n",
    "#     # channel attention\n",
    "#     self.max_pool_1 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=224, stride=224))\n",
    "#     self.max_pool_2 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=112, stride=112))\n",
    "#     self.max_pool_3 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=56, stride=56))\n",
    "#     self.max_pool_4 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=28, stride=28))\n",
    "#     self.max_pool_5 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=14, stride=14))\n",
    "#     self.avg_pool_1 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=224, stride=224))\n",
    "#     self.avg_pool_2 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=112, stride=112))\n",
    "#     self.avg_pool_3 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=56, stride=56))\n",
    "#     self.avg_pool_4 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=28, stride=28))\n",
    "#     self.avg_pool_5 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=14, stride=14))\n",
    "\n",
    "#     # features without considering vgg19 pooling layer\n",
    "#     self.features_1 = torch.nn.Sequential(*list(model.features.children())[:3])\n",
    "#     self.features_2 = torch.nn.Sequential(*list(model.features.children())[3:6])\n",
    "#     self.features_3 = torch.nn.Sequential(*list(model.features.children())[7:10])\n",
    "#     self.features_4 = torch.nn.Sequential(*list(model.features.children())[10:13])\n",
    "#     self.features_5 = torch.nn.Sequential(*list(model.features.children())[14:17])\n",
    "#     self.features_6 = torch.nn.Sequential(*list(model.features.children())[17:20])\n",
    "#     self.features_7 = torch.nn.Sequential(*list(model.features.children())[20:23])\n",
    "#     self.features_8 = torch.nn.Sequential(*list(model.features.children())[23:26])\n",
    "#     self.features_9 = torch.nn.Sequential(*list(model.features.children())[27:30])\n",
    "#     self.features_10 = torch.nn.Sequential(*list(model.features.children())[30:33])\n",
    "#     self.features_11 = torch.nn.Sequential(*list(model.features.children())[33:36])\n",
    "#     self.features_12 = torch.nn.Sequential(*list(model.features.children())[36:39])\n",
    "#     self.features_13 = torch.nn.Sequential(*list(model.features.children())[40:43])\n",
    "#     self.features_14 = torch.nn.Sequential(*list(model.features.children())[43:46])\n",
    "#     self.features_15 = torch.nn.Sequential(*list(model.features.children())[46:49])\n",
    "#     self.features_16 = torch.nn.Sequential(*list(model.features.children())[49:52])\n",
    "\n",
    "#     self.avgpool = nn.AdaptiveAvgPool2d(7)\n",
    "\n",
    "#     # classifier\n",
    "#     self.classifier = torch.nn.Sequential(\n",
    "#         torch.nn.Linear(25088, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 2),\n",
    "#         #torch.nn.Softmax(1),\n",
    "#     )\n",
    "\n",
    "\n",
    "#   # forward\n",
    "#   def forward(self, x):\n",
    "#     x = self.features_1(x)\n",
    "    \n",
    "#     #scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "    \n",
    "\n",
    "#     x = self.features_2(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_3(x)\n",
    "    \n",
    "\n",
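    "#     # NOTE(review): features_4 is never applied in this forward; only features_3 runs between the first and second self.pool - confirm this skip is intended\n",
    "#     #b, c, h, w = x.size()\n",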
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_5(x)\n",
    "    \n",
    "    \n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_6(x)\n",
    "    \n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_7(x)\n",
    "    \n",
    "    \n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_8(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_9(x)\n",
    "    \n",
    "    \n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_10(x)\n",
    "    \n",
    "    \n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_11(x)\n",
    "    \n",
    "    \n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_12(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_13(x)\n",
    "    \n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_14(x)\n",
    "    \n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_15(x)\n",
    "    \n",
    "\n",
    "#     #b, c, h, w = x.size()\n",
    "#     x = self.features_16(x)\n",
    "#     b, c, h, w = x.size()\n",
    "#     scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     scale = self.spatial_attention(scale)\n",
    "#     scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     #temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),scale)\n",
    "#     temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.avgpool(x)\n",
    "#     x = x.view(x.shape[0], -1)\n",
    "#     x = self.classifier(x)\n",
    "#     return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.105075Z",
     "iopub.status.busy": "2023-09-20T21:58:15.104940Z",
     "iopub.status.idle": "2023-09-20T21:58:15.108745Z",
     "shell.execute_reply": "2023-09-20T21:58:15.108296Z"
    },
    "id": "MbxY_ZK_hFf4"
   },
   "outputs": [],
   "source": [
    "# # soft attention + spatial attention, applied after features_15; features_16 takes the concatenated (512*2)-channel input\n",
    "# class VGG19_CBAM(torch.nn.Module):\n",
    "\n",
    "#   # init function\n",
    "#   def __init__(self, model, num_classes=2):\n",
    "#     super().__init__()\n",
    "\n",
    "#     # pool layer\n",
    "#     self.pool = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "\n",
    "#     self.beta = torch.nn.Parameter(torch.tensor([0.0]))\n",
    "\n",
    "#     # soft attention (1x1 conv + BatchNorm over 512 channels)\n",
    "#     self.soft_attention = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(512, 512, kernel_size=1, stride=1),\n",
    "#         #nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(512),\n",
    "#         #torch.nn.Sigmoid()\n",
    "#     )\n",
    "#     # spatial attention\n",
    "#     self.spatial_attention = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(2, 1, kernel_size=3, padding=1, stride=1),\n",
    "#         torch.nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(1)\n",
    "#     )\n",
    "\n",
    "#     # channel attention\n",
    "#     self.max_pool_1 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=224, stride=224))\n",
    "#     self.max_pool_2 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=112, stride=112))\n",
    "#     self.max_pool_3 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=56, stride=56))\n",
    "#     self.max_pool_4 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=28, stride=28))\n",
    "#     self.max_pool_5 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=14, stride=14))\n",
    "#     self.avg_pool_1 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=224, stride=224))\n",
    "#     self.avg_pool_2 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=112, stride=112))\n",
    "#     self.avg_pool_3 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=56, stride=56))\n",
    "#     self.avg_pool_4 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=28, stride=28))\n",
    "#     self.avg_pool_5 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=14, stride=14))\n",
    "\n",
    "#     # features without considering vgg19 pooling layer\n",
    "#     self.features_1 = torch.nn.Sequential(*list(model.features.children())[:3])\n",
    "#     self.features_2 = torch.nn.Sequential(*list(model.features.children())[3:6])\n",
    "#     self.features_3 = torch.nn.Sequential(*list(model.features.children())[7:10])\n",
    "#     self.features_4 = torch.nn.Sequential(*list(model.features.children())[10:13])\n",
    "#     self.features_5 = torch.nn.Sequential(*list(model.features.children())[14:17])\n",
    "#     self.features_6 = torch.nn.Sequential(*list(model.features.children())[17:20])\n",
    "#     self.features_7 = torch.nn.Sequential(*list(model.features.children())[20:23])\n",
    "#     self.features_8 = torch.nn.Sequential(*list(model.features.children())[23:26])\n",
    "#     self.features_9 = torch.nn.Sequential(*list(model.features.children())[27:30])\n",
    "#     self.features_10 = torch.nn.Sequential(*list(model.features.children())[30:33])\n",
    "#     self.features_11 = torch.nn.Sequential(*list(model.features.children())[33:36])\n",
    "#     self.features_12 = torch.nn.Sequential(*list(model.features.children())[36:39])\n",
    "#     self.features_13 = torch.nn.Sequential(*list(model.features.children())[40:43])\n",
    "#     self.features_14 = torch.nn.Sequential(*list(model.features.children())[43:46])\n",
    "#     self.features_15 = torch.nn.Sequential(*list(model.features.children())[46:49])\n",
    "#     self.features_16 = torch.nn.Sequential(torch.nn.Conv2d(512*2, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),\n",
    "#                                            torch.nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),\n",
    "#                                            torch.nn.ReLU(inplace=True))\n",
    "    \n",
    "\n",
    "#     self.avgpool = nn.AdaptiveAvgPool2d(7)\n",
    "\n",
    "#     # classifier\n",
    "#     self.classifier = torch.nn.Sequential(\n",
    "#         torch.nn.Linear(25088, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 2)\n",
    "#     )\n",
    "\n",
    "\n",
    "#   # forward\n",
    "#   def forward(self, x):\n",
    "#     x = self.features_1(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_2(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_3(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_4(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_5(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_6(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_7(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_8(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_9(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_10(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_11(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_12(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_13(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_14(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_15(x)\n",
    "#     x_conv = self.soft_attention(x)\n",
    "#     w_xy = self.soft_attention(x)\n",
    "#     y_conv = self.soft_attention(w_xy)\n",
    "#     temp_conv = torch.nn.functional.sigmoid(self.soft_attention(x_conv + y_conv))\n",
    "#     temp = w_xy + (w_xy * temp_conv)\n",
    "#     x = torch.cat((x,temp), dim =1)\n",
    "#     b, c, h, w = x.size()\n",
    "#     scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     scale = self.spatial_attention(scale)\n",
    "#     scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "    \n",
    "#     x = self.features_16(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.avgpool(x)\n",
    "#     x = x.view(x.shape[0], -1)\n",
    "#     x = self.classifier(x)\n",
    "#     return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.110346Z",
     "iopub.status.busy": "2023-09-20T21:58:15.110211Z",
     "iopub.status.idle": "2023-09-20T21:58:15.114091Z",
     "shell.execute_reply": "2023-09-20T21:58:15.113653Z"
    },
    "id": "yS4sHyn4e4PK"
   },
   "outputs": [],
   "source": [
    "# # soft attention + spatial (modified)\n",
    "# class VGG19_CBAM(torch.nn.Module):\n",
    "\n",
    "#   # init function\n",
    "#   def __init__(self, model, num_classes=2):\n",
    "#     super().__init__()\n",
    "\n",
    "#     # pool layer\n",
    "#     self.pool = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "\n",
    "#     self.beta = torch.nn.Parameter(torch.tensor([0.0]))\n",
    "\n",
    "#     # soft attention (1x1 conv + BatchNorm over 512 channels)\n",
    "#     self.soft_attention = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(512, 512, kernel_size=1, stride=1),\n",
    "#         #nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(512),\n",
    "#         #torch.nn.Sigmoid()\n",
    "#     )\n",
    "#     # spatial attention\n",
    "#     self.spatial_attention = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(2, 1, kernel_size=3, padding=1, stride=1),\n",
    "#         torch.nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(1)\n",
    "#     )\n",
    "\n",
    "#     # channel attention\n",
    "#     self.max_pool_1 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=224, stride=224))\n",
    "#     self.max_pool_2 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=112, stride=112))\n",
    "#     self.max_pool_3 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=56, stride=56))\n",
    "#     self.max_pool_4 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=28, stride=28))\n",
    "#     self.max_pool_5 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=14, stride=14))\n",
    "#     self.avg_pool_1 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=224, stride=224))\n",
    "#     self.avg_pool_2 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=112, stride=112))\n",
    "#     self.avg_pool_3 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=56, stride=56))\n",
    "#     self.avg_pool_4 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=28, stride=28))\n",
    "#     self.avg_pool_5 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=14, stride=14))\n",
    "\n",
    "#     # features without considering vgg19 pooling layer\n",
    "#     self.features_1 = torch.nn.Sequential(*list(model.features.children())[:3])\n",
    "#     self.features_2 = torch.nn.Sequential(*list(model.features.children())[3:6])\n",
    "#     self.features_3 = torch.nn.Sequential(*list(model.features.children())[7:10])\n",
    "#     self.features_4 = torch.nn.Sequential(*list(model.features.children())[10:13])\n",
    "#     self.features_5 = torch.nn.Sequential(*list(model.features.children())[14:17])\n",
    "#     self.features_6 = torch.nn.Sequential(*list(model.features.children())[17:20])\n",
    "#     self.features_7 = torch.nn.Sequential(*list(model.features.children())[20:23])\n",
    "#     self.features_8 = torch.nn.Sequential(*list(model.features.children())[23:26])\n",
    "#     self.features_9 = torch.nn.Sequential(*list(model.features.children())[27:30])\n",
    "#     self.features_10 = torch.nn.Sequential(*list(model.features.children())[30:33])\n",
    "#     self.features_11 = torch.nn.Sequential(*list(model.features.children())[33:36])\n",
    "#     self.features_12 = torch.nn.Sequential(*list(model.features.children())[36:39])\n",
    "#     self.features_13 = torch.nn.Sequential(*list(model.features.children())[40:43])\n",
    "#     self.features_14 = torch.nn.Sequential(*list(model.features.children())[43:46])\n",
    "#     self.features_15 = torch.nn.Sequential(*list(model.features.children())[46:49])\n",
    "#     self.features_16 = torch.nn.Sequential(torch.nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),\n",
    "#                                            torch.nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),\n",
    "#                                            torch.nn.ReLU(inplace=True))\n",
    "    \n",
    "\n",
    "#     self.avgpool = nn.AdaptiveAvgPool2d(7)\n",
    "\n",
    "#     # classifier\n",
    "#     self.classifier = torch.nn.Sequential(\n",
    "#         torch.nn.Linear(25088*2, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 2)\n",
    "#     )\n",
    "\n",
    "\n",
    "#   # forward\n",
    "#   def forward(self, x):\n",
    "#     x = self.features_1(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_2(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_3(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_4(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_5(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_6(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_7(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_8(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_9(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_10(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_11(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_12(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_13(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_14(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_15(x)\n",
    "    \n",
    "    \n",
    "#     x = self.features_16(x)\n",
    "#     x_conv = self.soft_attention(x)\n",
    "#     w_xy = self.soft_attention(x)\n",
    "#     y_conv = self.soft_attention(w_xy)\n",
    "#     temp_conv = torch.nn.functional.sigmoid(self.soft_attention(x_conv + y_conv))\n",
    "#     temp = w_xy + (w_xy * temp_conv)\n",
    "#     x = torch.cat((x,temp), dim =1)\n",
    "#     b, c, h, w = x.size()\n",
    "#     scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     scale = self.spatial_attention(scale)\n",
    "#     scale = torch.nn.functional.softmax(torch.matmul(torch.transpose(torch.reshape(scale, (b,1,h*w)),1,2),torch.reshape(scale, (b,1,h*w))), dim = 1)\n",
    "#     temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),torch.transpose(scale,1,2))\n",
    "#     #temp_x = self.beta * torch.matmul(torch.reshape(x, (b,c,h*w)),scale)\n",
    "#     x = x + torch.reshape(temp_x, (b,c,h,w))\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.avgpool(x)\n",
    "#     x = x.view(x.shape[0], -1)\n",
    "#     x = self.classifier(x)\n",
    "#     return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.115612Z",
     "iopub.status.busy": "2023-09-20T21:58:15.115478Z",
     "iopub.status.idle": "2023-09-20T21:58:15.119295Z",
     "shell.execute_reply": "2023-09-20T21:58:15.118863Z"
    },
    "id": "4JeRkwzUUvdX"
   },
   "outputs": [],
   "source": [
    "# #modified attention 1\n",
    "# class VGG19_CBAM(torch.nn.Module):\n",
    "\n",
    "#   # init function\n",
    "#   def __init__(self, model, num_classes=2):\n",
    "#     super().__init__()\n",
    "\n",
    "#     # pool layer\n",
    "#     self.pool = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "\n",
    "#     #self.beta = torch.nn.Parameter(torch.tensor([0.0]))\n",
    "\n",
    "#     # channel attention\n",
    "#     self.attention_x = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(512, 512, kernel_size=1, stride=1),\n",
    "#         nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(512),\n",
    "#         #torch.nn.Sigmoid()\n",
    "#     )\n",
    "#     self.attention_y = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(512, 512, kernel_size=1, stride=1),\n",
    "#         nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(512),\n",
    "#         #torch.nn.Sigmoid()\n",
    "#     )\n",
    "#     self.attention_xy = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(512, 512, kernel_size=1, stride=1),\n",
    "#         #nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(512),\n",
    "#         torch.nn.Sigmoid()\n",
    "#     )\n",
    "#     # spatial attention\n",
    "#     self.spatial_attention = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(2, 1, kernel_size=3, padding=1, stride=1),\n",
    "#         #torch.nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(1),\n",
    "#         torch.nn.Sigmoid()\n",
    "#     )\n",
    "\n",
    "#     # channel attention\n",
    "#     self.max_pool_1 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=224, stride=224))\n",
    "#     self.max_pool_2 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=112, stride=112))\n",
    "#     self.max_pool_3 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=56, stride=56))\n",
    "#     self.max_pool_4 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=28, stride=28))\n",
    "#     self.max_pool_5 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=14, stride=14))\n",
    "#     self.avg_pool_1 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=224, stride=224))\n",
    "#     self.avg_pool_2 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=112, stride=112))\n",
    "#     self.avg_pool_3 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=56, stride=56))\n",
    "#     self.avg_pool_4 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=28, stride=28))\n",
    "#     self.avg_pool_5 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=14, stride=14))\n",
    "\n",
    "#     # features without considering vgg19 pooling layer\n",
    "#     self.features_1 = torch.nn.Sequential(*list(model.features.children())[:3])\n",
    "#     self.features_2 = torch.nn.Sequential(*list(model.features.children())[3:6])\n",
    "#     self.features_3 = torch.nn.Sequential(*list(model.features.children())[7:10])\n",
    "#     self.features_4 = torch.nn.Sequential(*list(model.features.children())[10:13])\n",
    "#     self.features_5 = torch.nn.Sequential(*list(model.features.children())[14:17])\n",
    "#     self.features_6 = torch.nn.Sequential(*list(model.features.children())[17:20])\n",
    "#     self.features_7 = torch.nn.Sequential(*list(model.features.children())[20:23])\n",
    "#     self.features_8 = torch.nn.Sequential(*list(model.features.children())[23:26])\n",
    "#     self.features_9 = torch.nn.Sequential(*list(model.features.children())[27:30])\n",
    "#     self.features_10 = torch.nn.Sequential(*list(model.features.children())[30:33])\n",
    "#     self.features_11 = torch.nn.Sequential(*list(model.features.children())[33:36])\n",
    "#     self.features_12 = torch.nn.Sequential(*list(model.features.children())[36:39])\n",
    "#     self.features_13 = torch.nn.Sequential(*list(model.features.children())[40:43])\n",
    "#     self.features_14 = torch.nn.Sequential(*list(model.features.children())[43:46])\n",
    "#     self.features_15 = torch.nn.Sequential(*list(model.features.children())[46:49])\n",
    "#     self.features_16 = torch.nn.Sequential(torch.nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),\n",
    "#                                            torch.nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),\n",
    "#                                            torch.nn.ReLU(inplace=True))\n",
    "    \n",
    "\n",
    "#     self.avgpool = nn.AdaptiveAvgPool2d(7)\n",
    "\n",
    "#     # classifier\n",
    "#     self.classifier = torch.nn.Sequential(\n",
    "#         torch.nn.Linear(25088, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 2)\n",
    "#     )\n",
    "\n",
    "\n",
    "#   # forward\n",
    "#   def forward(self, x):\n",
    "#     x = self.features_1(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_2(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_3(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_4(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_5(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_6(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_7(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_8(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_9(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_10(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_11(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_12(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_13(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_14(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_15(x)\n",
    "    \n",
    "    \n",
    "    \n",
    "#     x = self.features_16(x)\n",
    "#     x_conv = self.attention_x(x)\n",
    "#     #w_xy = self.soft_attention(x)\n",
    "#     y_conv = self.attention_y(x)\n",
    "#     temp_conv = self.attention_xy(x_conv + y_conv)\n",
    "#     x = x_conv + (x_conv * temp_conv)\n",
    "#     #x = torch.cat((x,temp), dim =1)\n",
    "#     b, c, h, w = x.size()\n",
    "#     scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     scale = self.spatial_attention(scale)\n",
    "#     #scale = torch.nn.functional.sigmoid(scale)\n",
    "#     x = x + (x * scale)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.avgpool(x)\n",
    "#     x = x.view(x.shape[0], -1)\n",
    "#     x = self.classifier(x)\n",
    "#     return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.120854Z",
     "iopub.status.busy": "2023-09-20T21:58:15.120720Z",
     "iopub.status.idle": "2023-09-20T21:58:15.124609Z",
     "shell.execute_reply": "2023-09-20T21:58:15.124236Z"
    },
    "id": "BqWjOp3ZgtN2"
   },
   "outputs": [],
   "source": [
    "# #modified attention 1 spatial\n",
    "# class VGG19_CBAM(torch.nn.Module):\n",
    "\n",
    "#   # init function\n",
    "#   def __init__(self, model, num_classes=2):\n",
    "#     super().__init__()\n",
    "\n",
    "#     # pool layer\n",
    "#     self.pool = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "\n",
    "#     #self.beta = torch.nn.Parameter(torch.tensor([0.0]))\n",
    "\n",
    "#     # channel attention\n",
    "#     self.attention_x = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(512, 512, kernel_size=1, stride=1),\n",
    "#         nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(512),\n",
    "#         #torch.nn.Sigmoid()\n",
    "#     )\n",
    "#     self.attention_y = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(512, 512, kernel_size=1, stride=1),\n",
    "#         nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(512),\n",
    "#         #torch.nn.Sigmoid()\n",
    "#     )\n",
    "#     self.attention_xy = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(512, 512, kernel_size=1, stride=1),\n",
    "#         #nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(512),\n",
    "#         torch.nn.Sigmoid()\n",
    "#     )\n",
    "#     # spatial attention\n",
    "#     self.spatial_attention = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(2, 1, kernel_size=3, padding=1, stride=1),\n",
    "#         #torch.nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(1),\n",
    "#         torch.nn.Sigmoid()\n",
    "#     )\n",
    "\n",
    "#     # channel attention\n",
    "#     self.max_pool_1 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=224, stride=224))\n",
    "#     self.max_pool_2 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=112, stride=112))\n",
    "#     self.max_pool_3 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=56, stride=56))\n",
    "#     self.max_pool_4 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=28, stride=28))\n",
    "#     self.max_pool_5 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=14, stride=14))\n",
    "#     self.avg_pool_1 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=224, stride=224))\n",
    "#     self.avg_pool_2 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=112, stride=112))\n",
    "#     self.avg_pool_3 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=56, stride=56))\n",
    "#     self.avg_pool_4 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=28, stride=28))\n",
    "#     self.avg_pool_5 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=14, stride=14))\n",
    "\n",
    "#     # features without considering vgg19 pooling layer\n",
    "#     self.features_1 = torch.nn.Sequential(*list(model.features.children())[:3])\n",
    "#     self.features_2 = torch.nn.Sequential(*list(model.features.children())[3:6])\n",
    "#     self.features_3 = torch.nn.Sequential(*list(model.features.children())[7:10])\n",
    "#     self.features_4 = torch.nn.Sequential(*list(model.features.children())[10:13])\n",
    "#     self.features_5 = torch.nn.Sequential(*list(model.features.children())[14:17])\n",
    "#     self.features_6 = torch.nn.Sequential(*list(model.features.children())[17:20])\n",
    "#     self.features_7 = torch.nn.Sequential(*list(model.features.children())[20:23])\n",
    "#     self.features_8 = torch.nn.Sequential(*list(model.features.children())[23:26])\n",
    "#     self.features_9 = torch.nn.Sequential(*list(model.features.children())[27:30])\n",
    "#     self.features_10 = torch.nn.Sequential(*list(model.features.children())[30:33])\n",
    "#     self.features_11 = torch.nn.Sequential(*list(model.features.children())[33:36])\n",
    "#     self.features_12 = torch.nn.Sequential(*list(model.features.children())[36:39])\n",
    "#     self.features_13 = torch.nn.Sequential(*list(model.features.children())[40:43])\n",
    "#     self.features_14 = torch.nn.Sequential(*list(model.features.children())[43:46])\n",
    "#     self.features_15 = torch.nn.Sequential(*list(model.features.children())[46:49])\n",
    "#     self.features_16 = torch.nn.Sequential(torch.nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),\n",
    "#                                            torch.nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),\n",
    "#                                            torch.nn.ReLU(inplace=True))\n",
    "    \n",
    "\n",
    "#     self.avgpool = nn.AdaptiveAvgPool2d(7)\n",
    "\n",
    "#     # classifier\n",
    "#     self.classifier = torch.nn.Sequential(\n",
    "#         torch.nn.Linear(25088, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 2)\n",
    "#     )\n",
    "\n",
    "\n",
    "#   # forward\n",
    "#   def forward(self, x):\n",
    "#     x = self.features_1(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_2(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_3(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_4(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_5(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_6(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_7(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_8(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_9(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_10(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_11(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_12(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_13(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_14(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_15(x)\n",
    "    \n",
    "    \n",
    "    \n",
    "#     x = self.features_16(x)\n",
    "#     x_conv = self.attention_x(x)\n",
    "#     #w_xy = self.soft_attention(x)\n",
    "#     y_conv = self.attention_y(x)\n",
    "#     temp_conv = self.attention_xy(x_conv + y_conv)\n",
    "#     x = x_conv + (x_conv * temp_conv)\n",
    "#     #b, c, h, w = x.size()\n",
    "#     #scale = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)\n",
    "#     #scale = self.spatial_attention(scale)\n",
    "#     #temp = torch.nn.functional.softmax(scale,dim=1)\n",
    "#     #temp1 = torch.nn.functional.softmax(scale,dim=0)\n",
    "#     #scale = temp + temp1\n",
    "#     #x = x + (x * scale)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.avgpool(x)\n",
    "#     x = x.view(x.shape[0], -1)\n",
    "#     x = self.classifier(x)\n",
    "#     return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.126152Z",
     "iopub.status.busy": "2023-09-20T21:58:15.126016Z",
     "iopub.status.idle": "2023-09-20T21:58:15.130145Z",
     "shell.execute_reply": "2023-09-20T21:58:15.129720Z"
    },
    "id": "gI_b0IkmDpz2"
   },
   "outputs": [],
   "source": [
    "# modified attention 2 (active VGG19_CBAM variant)\n",
    "class VGG19_CBAM(torch.nn.Module):\n",
    "  \"\"\"VGG19 backbone + gated 1x1-conv attention + spatial-softmax attention.\n",
    "\n",
    "  Args:\n",
    "    model: backbone whose ``model.features`` children are sliced below. The\n",
    "      slices skip indices 6, 13, 26 and 39 and stop at 49, which presumably\n",
    "      matches the pooling layout of torchvision's vgg19_bn (TODO: confirm the\n",
    "      caller passes vgg19_bn). Layers are reused by reference, not copied, so\n",
    "      they share weights with ``model``.\n",
    "    num_classes: number of output logits (default 2).\n",
    "\n",
    "  forward(x) returns logits of shape (batch, num_classes).\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, model, num_classes=2):\n",
    "    super().__init__()\n",
    "\n",
    "    # 2x2 max-pool, applied explicitly between backbone stages and after attention\n",
    "    self.pool = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "\n",
    "    # gated attention: two 1x1-conv projections whose sum drives a sigmoid gate\n",
    "    self.attention_x = torch.nn.Sequential(\n",
    "        torch.nn.Conv2d(512, 512, kernel_size=1, stride=1),\n",
    "        torch.nn.PReLU(),\n",
    "        torch.nn.BatchNorm2d(512),\n",
    "    )\n",
    "    self.attention_y = torch.nn.Sequential(\n",
    "        torch.nn.Conv2d(512, 512, kernel_size=1, stride=1),\n",
    "        torch.nn.PReLU(),\n",
    "        torch.nn.BatchNorm2d(512),\n",
    "    )\n",
    "    self.attention_xy = torch.nn.Sequential(\n",
    "        torch.nn.Conv2d(512, 512, kernel_size=1, stride=1),\n",
    "        torch.nn.BatchNorm2d(512),\n",
    "        torch.nn.Sigmoid(),\n",
    "    )\n",
    "    # spatial attention: [channel-max, channel-mean] (2 channels) -> 1-channel map in (0, 1)\n",
    "    self.spatial_attention = torch.nn.Sequential(\n",
    "        torch.nn.Conv2d(2, 1, kernel_size=3, padding=1, stride=1),\n",
    "        torch.nn.BatchNorm2d(1),\n",
    "        torch.nn.Sigmoid(),\n",
    "    )\n",
    "\n",
    "    # backbone blocks taken from model.features, skipping its pooling layers\n",
    "    # (self.pool is applied explicitly in _backbone instead)\n",
    "    layers = list(model.features.children())\n",
    "    self.features_1 = torch.nn.Sequential(*layers[:3])\n",
    "    self.features_2 = torch.nn.Sequential(*layers[3:6])\n",
    "    self.features_3 = torch.nn.Sequential(*layers[7:10])\n",
    "    self.features_4 = torch.nn.Sequential(*layers[10:13])\n",
    "    self.features_5 = torch.nn.Sequential(*layers[14:17])\n",
    "    self.features_6 = torch.nn.Sequential(*layers[17:20])\n",
    "    self.features_7 = torch.nn.Sequential(*layers[20:23])\n",
    "    self.features_8 = torch.nn.Sequential(*layers[23:26])\n",
    "    self.features_9 = torch.nn.Sequential(*layers[27:30])\n",
    "    self.features_10 = torch.nn.Sequential(*layers[30:33])\n",
    "    self.features_11 = torch.nn.Sequential(*layers[33:36])\n",
    "    self.features_12 = torch.nn.Sequential(*layers[36:39])\n",
    "    self.features_13 = torch.nn.Sequential(*layers[40:43])\n",
    "    self.features_14 = torch.nn.Sequential(*layers[43:46])\n",
    "    self.features_15 = torch.nn.Sequential(*layers[46:49])\n",
    "    # freshly initialised conv/bn/relu used in place of the backbone's remaining layers (49+)\n",
    "    self.features_16 = torch.nn.Sequential(\n",
    "        torch.nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),\n",
    "        torch.nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),\n",
    "        torch.nn.ReLU(inplace=True),\n",
    "    )\n",
    "\n",
    "    self.avgpool = torch.nn.AdaptiveAvgPool2d(7)\n",
    "\n",
    "    # classifier: 512 channels * 7 * 7 (avgpool output) = 25088 inputs\n",
    "    self.classifier = torch.nn.Sequential(\n",
    "        torch.nn.Linear(25088, 4096),\n",
    "        torch.nn.ReLU(inplace=True),\n",
    "        torch.nn.Dropout(),\n",
    "        torch.nn.Linear(4096, 4096),\n",
    "        torch.nn.ReLU(inplace=True),\n",
    "        torch.nn.Dropout(),\n",
    "        torch.nn.Linear(4096, num_classes),\n",
    "    )\n",
    "\n",
    "  def forward(self, x):\n",
    "    \"\"\"Backbone -> gated + spatial attention -> max-pool -> avg-pool -> classifier.\"\"\"\n",
    "    x = self._backbone(x)\n",
    "    x = self._attention(x)\n",
    "    x = self.pool(x)\n",
    "    x = self.avgpool(x)\n",
    "    x = torch.flatten(x, 1)\n",
    "    return self.classifier(x)\n",
    "\n",
    "  def _backbone(self, x):\n",
    "    \"\"\"Run features_1..16, max-pooling after each of the first four stages.\"\"\"\n",
    "    stages = (\n",
    "        (self.features_1, self.features_2),\n",
    "        (self.features_3, self.features_4),\n",
    "        (self.features_5, self.features_6, self.features_7, self.features_8),\n",
    "        (self.features_9, self.features_10, self.features_11, self.features_12),\n",
    "    )\n",
    "    for stage in stages:\n",
    "      for block in stage:\n",
    "        x = block(x)\n",
    "      x = self.pool(x)\n",
    "    # the last stage is not pooled here; forward pools after the attention\n",
    "    for block in (self.features_13, self.features_14, self.features_15, self.features_16):\n",
    "      x = block(x)\n",
    "    return x\n",
    "\n",
    "  def _attention(self, x):\n",
    "    \"\"\"Apply the gated 1x1-conv attention, then the spatial attention.\"\"\"\n",
    "    x_conv = self.attention_x(x)\n",
    "    y_conv = self.attention_y(x)\n",
    "    gate = self.attention_xy(x_conv + y_conv)\n",
    "    # note: the residual is taken on x_conv, not on the backbone output x\n",
    "    x = x_conv + x_conv * gate\n",
    "\n",
    "    # per-position channel max and channel mean -> spatial map of shape (b, 1, h, w)\n",
    "    pooled = torch.cat((torch.max(x, dim=1, keepdim=True)[0],\n",
    "                        torch.mean(x, dim=1, keepdim=True)), dim=1)\n",
    "    scale = self.spatial_attention(pooled)\n",
    "    # softmax over all h*w positions of each sample: equivalent to\n",
    "    # exp(scale) / exp(scale).sum over (h, w), but numerically stable\n",
    "    scale = torch.softmax(scale.flatten(1), dim=1).reshape_as(scale)\n",
    "    return x + x * scale"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.131704Z",
     "iopub.status.busy": "2023-09-20T21:58:15.131569Z",
     "iopub.status.idle": "2023-09-20T21:58:15.135258Z",
     "shell.execute_reply": "2023-09-20T21:58:15.134829Z"
    },
    "id": "UbrDkxW5XNEB"
   },
   "outputs": [],
   "source": [
    "# #soft attention\n",
    "# class VGG19_CBAM(torch.nn.Module):\n",
    "\n",
    "#   # init function\n",
    "#   def __init__(self, model, num_classes=2):\n",
    "#     super().__init__()\n",
    "\n",
    "#     # pool layer\n",
    "#     self.pool = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "\n",
    "#     # spatial attention\n",
    "#     self.soft_attention = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(512, 512, kernel_size=1, stride=1),\n",
    "#         #nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(512),\n",
    "#         #torch.nn.Sigmoid()\n",
    "#     )\n",
    "\n",
    "#     # channel attention\n",
    "#     self.max_pool_1 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=224, stride=224))\n",
    "#     self.max_pool_2 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=112, stride=112))\n",
    "#     self.max_pool_3 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=56, stride=56))\n",
    "#     self.max_pool_4 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=28, stride=28))\n",
    "#     self.max_pool_5 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=14, stride=14))\n",
    "#     self.avg_pool_1 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=224, stride=224))\n",
    "#     self.avg_pool_2 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=112, stride=112))\n",
    "#     self.avg_pool_3 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=56, stride=56))\n",
    "#     self.avg_pool_4 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=28, stride=28))\n",
    "#     self.avg_pool_5 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=14, stride=14))\n",
    "\n",
    "#     # features without considering vgg19 pooling layer\n",
    "#     self.features_1 = torch.nn.Sequential(*list(model.features.children())[:3])\n",
    "#     self.features_2 = torch.nn.Sequential(*list(model.features.children())[3:6])\n",
    "#     self.features_3 = torch.nn.Sequential(*list(model.features.children())[7:10])\n",
    "#     self.features_4 = torch.nn.Sequential(*list(model.features.children())[10:13])\n",
    "#     self.features_5 = torch.nn.Sequential(*list(model.features.children())[14:17])\n",
    "#     self.features_6 = torch.nn.Sequential(*list(model.features.children())[17:20])\n",
    "#     self.features_7 = torch.nn.Sequential(*list(model.features.children())[20:23])\n",
    "#     self.features_8 = torch.nn.Sequential(*list(model.features.children())[23:26])\n",
    "#     self.features_9 = torch.nn.Sequential(*list(model.features.children())[27:30])\n",
    "#     self.features_10 = torch.nn.Sequential(*list(model.features.children())[30:33])\n",
    "#     self.features_11 = torch.nn.Sequential(*list(model.features.children())[33:36])\n",
    "#     self.features_12 = torch.nn.Sequential(*list(model.features.children())[36:39])\n",
    "#     self.features_13 = torch.nn.Sequential(*list(model.features.children())[40:43])\n",
    "#     self.features_14 = torch.nn.Sequential(*list(model.features.children())[43:46])\n",
    "#     self.features_15 = torch.nn.Sequential(*list(model.features.children())[46:49])\n",
    "#     self.features_16 = torch.nn.Sequential(torch.nn.Conv2d(512*2, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),\n",
    "#                                            torch.nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),\n",
    "#                                            torch.nn.ReLU(inplace=True))\n",
    "    \n",
    "\n",
    "#     self.avgpool = nn.AdaptiveAvgPool2d(7)\n",
    "\n",
    "#     # classifier\n",
    "#     self.classifier = torch.nn.Sequential(\n",
    "#         torch.nn.Linear(25088, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 2)\n",
    "#     )\n",
    "\n",
    "\n",
    "#   # forward\n",
    "#   def forward(self, x):\n",
    "#     x = self.features_1(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_2(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_3(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_4(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_5(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_6(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_7(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_8(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_9(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_10(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_11(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_12(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_13(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_14(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_15(x)\n",
    "#     x_conv = self.soft_attention(x)\n",
    "#     w_xy = self.soft_attention(x)\n",
    "#     y_conv = self.soft_attention(w_xy)\n",
    "#     temp_conv = torch.nn.functional.sigmoid(self.soft_attention(x_conv + y_conv))\n",
    "#     temp = w_xy + (w_xy * temp_conv)\n",
    "#     x = torch.cat((x,temp), dim =1)\n",
    "    \n",
    "#     x = self.features_16(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.avgpool(x)\n",
    "#     x = x.view(x.shape[0], -1)\n",
    "#     x = self.classifier(x)\n",
    "#     return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.136820Z",
     "iopub.status.busy": "2023-09-20T21:58:15.136686Z",
     "iopub.status.idle": "2023-09-20T21:58:15.140176Z",
     "shell.execute_reply": "2023-09-20T21:58:15.139743Z"
    },
    "id": "C7RcEiolmJP4"
   },
   "outputs": [],
   "source": [
    "# #soft attention modified\n",
    "# class VGG19_CBAM(torch.nn.Module):\n",
    "\n",
    "#   # init function\n",
    "#   def __init__(self, model, num_classes=2):\n",
    "#     super().__init__()\n",
    "\n",
    "#     # pool layer\n",
    "#     self.pool = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "\n",
    "#     # spatial attention\n",
    "#     self.soft_attention = torch.nn.Sequential(\n",
    "#         torch.nn.Conv2d(512, 512, kernel_size=1, stride=1),\n",
    "#         #nn.PReLU(),\n",
    "#         torch.nn.BatchNorm2d(512),\n",
    "#         #torch.nn.Sigmoid()\n",
    "#     )\n",
    "\n",
    "#     # channel attention\n",
    "#     self.max_pool_1 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=224, stride=224))\n",
    "#     self.max_pool_2 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=112, stride=112))\n",
    "#     self.max_pool_3 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=56, stride=56))\n",
    "#     self.max_pool_4 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=28, stride=28))\n",
    "#     self.max_pool_5 = torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=14, stride=14))\n",
    "#     self.avg_pool_1 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=224, stride=224))\n",
    "#     self.avg_pool_2 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=112, stride=112))\n",
    "#     self.avg_pool_3 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=56, stride=56))\n",
    "#     self.avg_pool_4 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=28, stride=28))\n",
    "#     self.avg_pool_5 = torch.nn.Sequential(torch.nn.AvgPool2d(kernel_size=14, stride=14))\n",
    "\n",
    "#     # features without considering vgg19 pooling layer\n",
    "#     self.features_1 = torch.nn.Sequential(*list(model.features.children())[:3])\n",
    "#     self.features_2 = torch.nn.Sequential(*list(model.features.children())[3:6])\n",
    "#     self.features_3 = torch.nn.Sequential(*list(model.features.children())[7:10])\n",
    "#     self.features_4 = torch.nn.Sequential(*list(model.features.children())[10:13])\n",
    "#     self.features_5 = torch.nn.Sequential(*list(model.features.children())[14:17])\n",
    "#     self.features_6 = torch.nn.Sequential(*list(model.features.children())[17:20])\n",
    "#     self.features_7 = torch.nn.Sequential(*list(model.features.children())[20:23])\n",
    "#     self.features_8 = torch.nn.Sequential(*list(model.features.children())[23:26])\n",
    "#     self.features_9 = torch.nn.Sequential(*list(model.features.children())[27:30])\n",
    "#     self.features_10 = torch.nn.Sequential(*list(model.features.children())[30:33])\n",
    "#     self.features_11 = torch.nn.Sequential(*list(model.features.children())[33:36])\n",
    "#     self.features_12 = torch.nn.Sequential(*list(model.features.children())[36:39])\n",
    "#     self.features_13 = torch.nn.Sequential(*list(model.features.children())[40:43])\n",
    "#     self.features_14 = torch.nn.Sequential(*list(model.features.children())[43:46])\n",
    "#     self.features_15 = torch.nn.Sequential(*list(model.features.children())[46:49])\n",
    "#     self.features_16 = torch.nn.Sequential(torch.nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),\n",
    "#                                            torch.nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),\n",
    "#                                            torch.nn.ReLU(inplace=True))\n",
    "    \n",
    "\n",
    "#     self.avgpool = nn.AdaptiveAvgPool2d(7)\n",
    "\n",
    "#     # classifier\n",
    "#     self.classifier = torch.nn.Sequential(\n",
    "#         torch.nn.Linear(25088*2, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 4096),\n",
    "#         torch.nn.ReLU(inplace=True),\n",
    "#         torch.nn.Dropout(),\n",
    "#         torch.nn.Linear(4096, 2)\n",
    "#     )\n",
    "\n",
    "\n",
    "#   # forward\n",
    "#   def forward(self, x):\n",
    "#     x = self.features_1(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_2(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_3(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_4(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_5(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_6(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_7(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_8(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_9(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_10(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_11(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_12(x)\n",
    "    \n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.features_13(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_14(x)\n",
    "    \n",
    "\n",
    "#     x = self.features_15(x)\n",
    "    \n",
    "    \n",
    "#     x = self.features_16(x)\n",
    "#     x_conv = self.soft_attention(x)\n",
    "#     w_xy = self.soft_attention(x)\n",
    "#     y_conv = self.soft_attention(w_xy)\n",
    "#     temp_conv = torch.nn.functional.sigmoid(self.soft_attention(x_conv + y_conv))\n",
    "#     temp = w_xy + (w_xy * temp_conv)\n",
    "#     x = torch.cat((x,temp), dim =1)\n",
    "#     x = self.pool(x)\n",
    "\n",
    "#     x = self.avgpool(x)\n",
    "#     x = x.view(x.shape[0], -1)\n",
    "#     x = self.classifier(x)\n",
    "#     return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.141805Z",
     "iopub.status.busy": "2023-09-20T21:58:15.141670Z",
     "iopub.status.idle": "2023-09-20T21:58:15.169733Z",
     "shell.execute_reply": "2023-09-20T21:58:15.169289Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import os\n",
    "import sys\n",
    "\n",
    "class SeparableConv2d(nn.Module):\n",
    "  \"\"\"Depthwise-separable convolution: a per-channel spatial conv followed by a 1x1 pointwise conv.\"\"\"\n",
    "\n",
    "  def __init__(self, c_in, c_out, ks, stride=1, padding=0, dilation=1, bias=False):\n",
    "    super().__init__()\n",
    "    # Depthwise stage: groups=c_in gives each input channel its own spatial filter.\n",
    "    self.c = nn.Conv2d(c_in, c_in, kernel_size=ks, stride=stride, padding=padding,\n",
    "                       dilation=dilation, groups=c_in, bias=bias)\n",
    "    # Pointwise stage: 1x1 conv that mixes channels and maps c_in -> c_out.\n",
    "    self.pointwise = nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0,\n",
    "                               dilation=1, groups=1, bias=bias)\n",
    "\n",
    "  def forward(self, x):\n",
    "    return self.pointwise(self.c(x))\n",
    "\n",
    "class Block(nn.Module):\n",
    "  \"\"\"Xception residual unit.\n",
    "\n",
    "  Main path: `reps` x (ReLU -> 3x3 SeparableConv2d -> BatchNorm), followed by a\n",
    "  3x3 max-pool when stride != 1. Shortcut: identity, or a 1x1 strided conv + BN\n",
    "  when the channel count or stride changes. Output = main path + shortcut.\n",
    "\n",
    "  start_with_relu: keep the leading ReLU of the main path (dropped when False).\n",
    "  grow_first: change channels c_in -> c_out in the first separable conv (True)\n",
    "    or in the last one (False).\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, c_in, c_out, reps, stride=1, start_with_relu=True, grow_first=True):\n",
    "    super().__init__()\n",
    "\n",
    "    # Projection shortcut only when the residual shape changes.\n",
    "    needs_projection = (c_out != c_in) or (stride != 1)\n",
    "    self.skip = nn.Conv2d(c_in, c_out, 1, stride=stride, bias=False) if needs_projection else None\n",
    "    self.skip_bn = nn.BatchNorm2d(c_out) if needs_projection else None\n",
    "\n",
    "    self.relu = nn.ReLU(inplace=True)\n",
    "\n",
    "    # (in, out) channels of each separable conv, in order.\n",
    "    if grow_first:\n",
    "      plan = [(c_in, c_out)] + [(c_out, c_out)] * (reps - 1)\n",
    "    else:\n",
    "      plan = [(c_in, c_in)] * (reps - 1) + [(c_in, c_out)]\n",
    "\n",
    "    layers = []\n",
    "    for src, dst in plan:\n",
    "      layers += [self.relu,\n",
    "                 SeparableConv2d(src, dst, 3, stride=1, padding=1, bias=False),\n",
    "                 nn.BatchNorm2d(dst)]\n",
    "\n",
    "    # The leading ReLU is either removed or replaced by a non-inplace copy, so it\n",
    "    # never overwrites `inp`, which forward() still needs for the shortcut.\n",
    "    if start_with_relu:\n",
    "      layers[0] = nn.ReLU(inplace=False)\n",
    "    else:\n",
    "      layers = layers[1:]\n",
    "\n",
    "    if stride != 1:\n",
    "      layers.append(nn.MaxPool2d(3, stride, 1))\n",
    "    self.rep = nn.Sequential(*layers)\n",
    "\n",
    "  def forward(self, inp):\n",
    "    out = self.rep(inp)\n",
    "    shortcut = inp if self.skip is None else self.skip_bn(self.skip(inp))\n",
    "    out += shortcut\n",
    "    return out\n",
    "\n",
    "class RegressionMap(nn.Module):\n",
    "  \"\"\"Predicts a single-channel mask in (0, 1) from a c_in-channel feature map.\"\"\"\n",
    "\n",
    "  def __init__(self, c_in):\n",
    "    super().__init__()\n",
    "    self.c = SeparableConv2d(c_in, 1, 3, stride=1, padding=1, bias=False)\n",
    "    self.s = nn.Sigmoid()\n",
    "\n",
    "  def forward(self, x):\n",
    "    # Returns (mask, None); the None fills the `vec` slot that Xception.features unpacks.\n",
    "    return self.s(self.c(x)), None\n",
    "\n",
    "class MyAttention(torch.nn.Module):\n",
    "  \"\"\"Two-stage attention over a feature map; returns (attended features, None).\n",
    "\n",
    "  Stage 1: two 1x1-conv branches of the input are summed and passed through a\n",
    "  sigmoid gate g; the output is x_conv * (1 + g).\n",
    "  Stage 2: a 3x3 conv over the per-pixel channel max and mean gives one score per\n",
    "  position; the scores are softmax-normalised over all H*W positions of each\n",
    "  sample and applied as x * (1 + scale).\n",
    "\n",
    "  channels: input/output channel count. The default 728 matches the block7\n",
    "  output of `Xception`, where the map module is applied.\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, channels=728):\n",
    "    super(MyAttention, self).__init__()\n",
    "\n",
    "    # Channel-mixing branches (1x1 conv -> PReLU -> BN).\n",
    "    self.attention_x = torch.nn.Sequential(\n",
    "        torch.nn.Conv2d(channels, channels, kernel_size=1, stride=1),\n",
    "        nn.PReLU(),\n",
    "        torch.nn.BatchNorm2d(channels),\n",
    "    )\n",
    "    self.attention_y = torch.nn.Sequential(\n",
    "        torch.nn.Conv2d(channels, channels, kernel_size=1, stride=1),\n",
    "        nn.PReLU(),\n",
    "        torch.nn.BatchNorm2d(channels),\n",
    "    )\n",
    "    # Sigmoid gate computed from the sum of both branches.\n",
    "    self.attention_xy = torch.nn.Sequential(\n",
    "        torch.nn.Conv2d(channels, channels, kernel_size=1, stride=1),\n",
    "        torch.nn.BatchNorm2d(channels),\n",
    "        torch.nn.Sigmoid()\n",
    "    )\n",
    "    # Spatial attention over the stacked [channel-max, channel-mean] maps.\n",
    "    self.spatial_attention = torch.nn.Sequential(\n",
    "        torch.nn.Conv2d(2, 1, kernel_size=3, padding=1, stride=1),\n",
    "        torch.nn.BatchNorm2d(1),\n",
    "        torch.nn.Sigmoid()\n",
    "    )\n",
    "\n",
    "  def forward(self, x):\n",
    "    x_conv = self.attention_x(x)\n",
    "    y_conv = self.attention_y(x)\n",
    "    gate = self.attention_xy(x_conv + y_conv)\n",
    "    x = x_conv + (x_conv * gate)\n",
    "\n",
    "    # (b, 2, h, w): per-pixel max and mean across channels.\n",
    "    pooled = torch.cat((x.max(dim=1, keepdim=True)[0], x.mean(dim=1, keepdim=True)), dim=1)\n",
    "    scale = self.spatial_attention(pooled)\n",
    "    # Softmax over the h*w positions of each sample (same as exp / sum over dims 2, 3).\n",
    "    scale = torch.softmax(scale.flatten(2), dim=-1).view_as(scale)\n",
    "    x = x + (x * scale)\n",
    "    return x, None\n",
    "\n",
    "class EmptyMap(nn.Module):\n",
    "  \"\"\"No-op map: returns a mask of ones so `x * mask` leaves the features unchanged.\"\"\"\n",
    "\n",
    "  def __init__(self):\n",
    "    super(EmptyMap, self).__init__()\n",
    "\n",
    "  def forward(self, x):\n",
    "    # Create the mask on x's device and dtype instead of hard-coding .cuda(), so\n",
    "    # the model also runs on CPU or on a GPU other than the default one.\n",
    "    return x.new_ones(1), None\n",
    "\n",
    "class TemplateMap(nn.Module):\n",
    "  \"\"\"Predicts 10 template weights from the features and mixes the templates into a mask.\n",
    "\n",
    "  templates: tensor that must reshape to (10, 361); the mixed result is reshaped\n",
    "    to a (batch, 1, 19, 19) mask. Returns (mask, weights of shape (batch, 10)).\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, c_in, templates):\n",
    "    super(TemplateMap, self).__init__()\n",
    "    self.c = Block(c_in, 364, 2, 2, start_with_relu=True, grow_first=False)\n",
    "    self.l = nn.Linear(364, 10)\n",
    "    self.relu = nn.ReLU(inplace=True)\n",
    "\n",
    "    # A plain tensor is stored as an ordinary attribute, which Module.to() does not move.\n",
    "    self.templates = templates\n",
    "\n",
    "  def forward(self, x):\n",
    "    v = self.c(x)\n",
    "    v = self.relu(v)\n",
    "    v = F.adaptive_avg_pool2d(v, (1,1))\n",
    "    v = v.view(v.size(0), -1)\n",
    "    v = self.l(v)\n",
    "    # Bring the templates to the activations' device/dtype (net.to(device) skips them).\n",
    "    templates = self.templates.to(device=v.device, dtype=v.dtype)\n",
    "    mask = torch.mm(v, templates.reshape(10, 361))\n",
    "    mask = mask.reshape(x.shape[0], 1, 19, 19)\n",
    "\n",
    "    return mask, v\n",
    "\n",
    "class PCATemplateMap(nn.Module):\n",
    "  \"\"\"Mixes the templates with weights taken from an eigen-decomposition of the features.\n",
    "\n",
    "  The templates B (10 x 361) are projected onto the Gram matrix of the centred\n",
    "  spatial feature rows; the eigenvector of the largest eigenvalue of the\n",
    "  resulting 10 x 10 matrix gives the weights. Returns (mask of shape\n",
    "  (batch, 1, 19, 19), weights of shape (batch, 10)).\n",
    "  NOTE(review): x must have h * w == 361 (19 x 19) for the bmm shapes to agree.\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, templates):\n",
    "    super(PCATemplateMap, self).__init__()\n",
    "    self.templates = templates\n",
    "\n",
    "  def forward(self, x):\n",
    "    # (b, h*w, c): one row per spatial position, centred by its mean over channels.\n",
    "    fe = x.view(x.shape[0], x.shape[1], x.shape[2]*x.shape[3])\n",
    "    fe = torch.transpose(fe, 1, 2)\n",
    "    mu = torch.mean(fe, 2, keepdim=True)\n",
    "    fea_diff = fe - mu\n",
    "\n",
    "    # (b, h*w, h*w) Gram matrix of the centred rows.\n",
    "    cov_fea = torch.bmm(fea_diff, torch.transpose(fea_diff, 1, 2))\n",
    "    # Templates are a plain attribute, so bring them to x's device/dtype here.\n",
    "    templates = self.templates.to(device=x.device, dtype=x.dtype)\n",
    "    B = templates.reshape(1, 10, 361).repeat(x.shape[0], 1, 1)\n",
    "    D = torch.bmm(torch.bmm(B, cov_fea), torch.transpose(B, 1, 2))\n",
    "    # Tensor.symeig is deprecated (removed in newer PyTorch); torch.linalg.eigh also\n",
    "    # returns eigenvalues in ascending order with matching eigenvector columns.\n",
    "    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):\n",
    "      _, eigen_vector = torch.linalg.eigh(D)\n",
    "    else:\n",
    "      _, eigen_vector = D.symeig(eigenvectors=True)\n",
    "    # Last column = eigenvector of the largest eigenvalue (index 9 of 10).\n",
    "    v = eigen_vector[:, :, -1]\n",
    "    mask = torch.mm(v, templates.reshape(10, 361))\n",
    "    mask = mask.reshape(x.shape[0], 1, 19, 19)\n",
    "    return mask, v\n",
    "\n",
    "class Xception(nn.Module):\n",
    "  \"\"\"\n",
    "  Xception optimized for the ImageNet dataset, as specified in\n",
    "  https://arxiv.org/pdf/1610.02357.pdf\n",
    "\n",
    "  A map module selected by `maptype` is applied to the 728-channel output of\n",
    "  block7; its mask is multiplied into the features before block8.\n",
    "\n",
    "  Args:\n",
    "    maptype: 'none', 'reg', 'tmp', 'pca_tmp' or 'myatt'.\n",
    "    templates: template tensor used by the 'tmp' and 'pca_tmp' maps.\n",
    "    num_classes: output size of `last_linear`.\n",
    "  Raises:\n",
    "    ValueError: if `maptype` is not one of the values above.\n",
    "\n",
    "  forward() returns (logits, mask, vec), where (mask, vec) come from the map module.\n",
    "  \"\"\"\n",
    "  def __init__(self, maptype, templates, num_classes=1000):\n",
    "    super(Xception, self).__init__()\n",
    "    self.num_classes = num_classes\n",
    "\n",
    "    # Entry flow stem.\n",
    "    self.conv1 = nn.Conv2d(3, 32, 3,2, 0, bias=False)\n",
    "    self.bn1 = nn.BatchNorm2d(32)\n",
    "    self.relu = nn.ReLU(inplace=True)\n",
    "\n",
    "    self.conv2 = nn.Conv2d(32,64,3,bias=False)\n",
    "    self.bn2 = nn.BatchNorm2d(64)\n",
    "\n",
    "    # Entry flow blocks (stride 2).\n",
    "    self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True)\n",
    "    self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True)\n",
    "    self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True)\n",
    "    # Middle flow: eight 728-channel blocks at stride 1.\n",
    "    self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True)\n",
    "    self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True)\n",
    "    self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True)\n",
    "    self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True)\n",
    "    self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True)\n",
    "    self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True)\n",
    "    self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True)\n",
    "    self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True)\n",
    "    # Exit flow.\n",
    "    self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)\n",
    "\n",
    "    self.conv3 = SeparableConv2d(1024,1536,3,1,1)\n",
    "    self.bn3 = nn.BatchNorm2d(1536)\n",
    "\n",
    "    self.conv4 = SeparableConv2d(1536,2048,3,1,1)\n",
    "    self.bn4 = nn.BatchNorm2d(2048)\n",
    "\n",
    "    self.last_linear = nn.Linear(2048, num_classes)\n",
    "\n",
    "    if maptype == 'none':\n",
    "      self.map = EmptyMap()\n",
    "    elif maptype == 'reg':\n",
    "      self.map = RegressionMap(728)\n",
    "    elif maptype == 'tmp':\n",
    "      self.map = TemplateMap(728, templates)\n",
    "    elif maptype == 'pca_tmp':\n",
    "      # PCATemplateMap takes the templates themselves; it has no channel argument.\n",
    "      self.map = PCATemplateMap(templates)\n",
    "    elif maptype == 'myatt':\n",
    "      self.map = MyAttention()\n",
    "    else:\n",
    "      # Raise instead of sys.exit(): exiting from library code terminates scripts\n",
    "      # and hides the cause in notebooks.\n",
    "      raise ValueError('Unknown map type: `{0}`'.format(maptype))\n",
    "\n",
    "  def features(self, input):\n",
    "    \"\"\"Backbone forward pass; returns (bn4 features, mask, vec).\"\"\"\n",
    "    x = self.conv1(input)\n",
    "    x = self.bn1(x)\n",
    "    x = self.relu(x)\n",
    "\n",
    "    x = self.conv2(x)\n",
    "    x = self.bn2(x)\n",
    "    x = self.relu(x)\n",
    "\n",
    "    x = self.block1(x)\n",
    "    x = self.block2(x)\n",
    "    x = self.block3(x)\n",
    "    x = self.block4(x)\n",
    "    x = self.block5(x)\n",
    "    x = self.block6(x)\n",
    "    x = self.block7(x)\n",
    "    # Apply the selected map to the 728-channel block7 features.\n",
    "    mask, vec = self.map(x)\n",
    "    x = x * mask\n",
    "    x = self.block8(x)\n",
    "    x = self.block9(x)\n",
    "    x = self.block10(x)\n",
    "    x = self.block11(x)\n",
    "    x = self.block12(x)\n",
    "    x = self.conv3(x)\n",
    "    x = self.bn3(x)\n",
    "    x = self.relu(x)\n",
    "\n",
    "    x = self.conv4(x)\n",
    "    x = self.bn4(x)\n",
    "    return x, mask, vec\n",
    "\n",
    "  def logits(self, features):\n",
    "    \"\"\"ReLU -> global average pool -> last_linear.\"\"\"\n",
    "    x = self.relu(features)\n",
    "    x = F.adaptive_avg_pool2d(x, (1, 1))\n",
    "    x = x.view(x.size(0), -1)\n",
    "    x = self.last_linear(x)\n",
    "    return x\n",
    "\n",
    "  def forward(self, input):\n",
    "    x, mask, vec = self.features(input)\n",
    "    x = self.logits(x)\n",
    "    return x, mask, vec\n",
    "\n",
    "def init_weights(m):\n",
    "  \"\"\"Initialise one module in place; meant for `model.apply(init_weights)`.\n",
    "\n",
    "  SeparableConv2d: both inner convs ~ N(0, 0.01), biases zeroed.\n",
    "  *Conv* / *Linear*: weight ~ N(0, 0.01), bias zeroed.\n",
    "  *BatchNorm*: weight ~ N(1, 0.01), bias zeroed (skipped when affine=False).\n",
    "  *LSTM*: parameters whose name contains 'weight' ~ N(0, 0.01), 'bias' zeroed.\n",
    "  \"\"\"\n",
    "  classname = m.__class__.__name__\n",
    "  if classname.find('SeparableConv2d') != -1:\n",
    "    m.c.weight.data.normal_(0.0, 0.01)\n",
    "    if m.c.bias is not None:\n",
    "      m.c.bias.data.fill_(0)\n",
    "    m.pointwise.weight.data.normal_(0.0, 0.01)\n",
    "    if m.pointwise.bias is not None:\n",
    "      m.pointwise.bias.data.fill_(0)\n",
    "  elif classname.find('Conv') != -1 or classname.find('Linear') != -1:\n",
    "    m.weight.data.normal_(0.0, 0.01)\n",
    "    if m.bias is not None:\n",
    "      m.bias.data.fill_(0)\n",
    "  elif classname.find('BatchNorm') != -1:\n",
    "    # BatchNorm with affine=False has weight/bias set to None.\n",
    "    if m.weight is not None:\n",
    "      m.weight.data.normal_(1.0, 0.01)\n",
    "    if m.bias is not None:\n",
    "      m.bias.data.fill_(0)\n",
    "  elif classname.find('LSTM') != -1:\n",
    "    # Dispatch on the parameter *name* (e.g. 'weight_ih_l0', 'bias_hh_l0').\n",
    "    for name, param in m.named_parameters(recurse=False):\n",
    "      if 'weight' in name:\n",
    "        param.data.normal_(0.0, 0.01)\n",
    "      elif 'bias' in name:\n",
    "        param.data.fill_(0)\n",
    "\n",
    "class Model:\n",
    "  \"\"\"Wrapper around `Xception` that handles weight initialisation and checkpoint I/O.\n",
    "\n",
    "  Args:\n",
    "    maptype: map module for `Xception` ('none', 'reg', 'tmp', 'pca_tmp', 'myatt').\n",
    "    templates: templates for the 'tmp' / 'pca_tmp' maps.\n",
    "    num_classes: number of output classes.\n",
    "    load_pretrain: load './xception-b5690688.pth' non-strictly (its 'fc.*'\n",
    "      classifier is dropped) instead of initialising with `init_weights`.\n",
    "  \"\"\"\n",
    "  def __init__(self, maptype='none', templates=None, num_classes=2, load_pretrain=True):\n",
    "    model = Xception(maptype, templates, num_classes=num_classes)\n",
    "    if load_pretrain:\n",
    "      # torch.load unpickles the file: only load trusted checkpoints.\n",
    "      state_dict = torch.load('./xception-b5690688.pth', map_location='cpu')\n",
    "      # Give pointwise weights two trailing singleton dims to match nn.Conv2d's\n",
    "      # (out, in, 1, 1) kernel shape (presumably stored as 2-D in the checkpoint).\n",
    "      for name, weights in list(state_dict.items()):\n",
    "        if 'pointwise' in name:\n",
    "          state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)\n",
    "      # Drop the original classifier; `last_linear` is sized for num_classes.\n",
    "      state_dict.pop('fc.weight', None)\n",
    "      state_dict.pop('fc.bias', None)\n",
    "      model.load_state_dict(state_dict, strict=False)\n",
    "    else:\n",
    "      model.apply(init_weights)\n",
    "    self.model = model\n",
    "\n",
    "  def save(self, epoch, optim, model_dir):\n",
    "    \"\"\"Save model and optimizer state to model_dir/{epoch:06d}.tar.\"\"\"\n",
    "    state = {'net': self.model.state_dict(), 'optim': optim.state_dict()}\n",
    "    torch.save(state, os.path.join(model_dir, '{0:06d}.tar'.format(epoch)))\n",
    "    print('Saved model `{0}`'.format(epoch))\n",
    "\n",
    "  def load(self, epoch, model_dir):\n",
    "    \"\"\"Load the model weights written by save() for `epoch`; prints a message if missing.\"\"\"\n",
    "    # Same path scheme as save(), with or without a trailing slash on model_dir.\n",
    "    filename = os.path.join(model_dir, '{0:06d}.tar'.format(epoch))\n",
    "    print('Loading model from {0}'.format(filename))\n",
    "    if os.path.exists(filename):\n",
    "      # map_location='cpu' lets GPU-saved checkpoints load anywhere;\n",
    "      # load_state_dict copies the tensors onto the model's own device.\n",
    "      state = torch.load(filename, map_location='cpu')\n",
    "      self.model.load_state_dict(state['net'])\n",
    "    else:\n",
    "      print('Failed to load model from {0}'.format(filename))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.171438Z",
     "iopub.status.busy": "2023-09-20T21:58:15.171272Z",
     "iopub.status.idle": "2023-09-20T21:58:15.429717Z",
     "shell.execute_reply": "2023-09-20T21:58:15.429252Z"
    }
   },
   "outputs": [],
   "source": [
    "# Xception with the regression-mask map ('reg'), initialised via init_weights.\n",
    "# NOTE(review): `net` is rebound to VGG19_CBAM in the next cell, so this model is only printed.\n",
    "MODEL = Model(maptype='reg', num_classes=2, load_pretrain=False)\n",
    "net = MODEL.model\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.431467Z",
     "iopub.status.busy": "2023-09-20T21:58:15.431299Z",
     "iopub.status.idle": "2023-09-20T21:58:15.433825Z",
     "shell.execute_reply": "2023-09-20T21:58:15.433288Z"
    },
    "id": "-0qLREVMICdd",
    "outputId": "ebe4a6b5-509d-49c6-a2a4-5b520943904a"
   },
   "outputs": [],
   "source": [
    "# torchvision VGG19-BN with pretrained weights, wrapped by VGG19_CBAM.\n",
    "# `2` is presumably num_classes - confirm against the VGG19_CBAM definition in use.\n",
    "pretrained_model = torchvision.models.vgg19_bn(pretrained=True)\n",
    "net = VGG19_CBAM(pretrained_model, 2)\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.435353Z",
     "iopub.status.busy": "2023-09-20T21:58:15.435190Z",
     "iopub.status.idle": "2023-09-20T21:58:15.437741Z",
     "shell.execute_reply": "2023-09-20T21:58:15.437364Z"
    },
    "id": "pj8jdNGfOE5l"
   },
   "outputs": [],
   "source": [
    "# net.classifier[6] = nn.Linear(in_features=4096, out_features=2, bias=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.439253Z",
     "iopub.status.busy": "2023-09-20T21:58:15.439091Z",
     "iopub.status.idle": "2023-09-20T21:58:15.441058Z",
     "shell.execute_reply": "2023-09-20T21:58:15.440720Z"
    },
    "id": "LVmvibHSyrum"
   },
   "outputs": [],
   "source": [
    "# net.fc = torch.nn.Sequential(\n",
    "#     torch.nn.Linear(2048, 2, bias=True)\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.442531Z",
     "iopub.status.busy": "2023-09-20T21:58:15.442373Z",
     "iopub.status.idle": "2023-09-20T21:58:15.444407Z",
     "shell.execute_reply": "2023-09-20T21:58:15.443978Z"
    },
    "id": "tY6CEEt9OE8c",
    "outputId": "be16bbf8-c2f8-4446-9848-677e25248fab"
   },
   "outputs": [],
   "source": [
    "# net.classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.445930Z",
     "iopub.status.busy": "2023-09-20T21:58:15.445755Z",
     "iopub.status.idle": "2023-09-20T21:58:15.448119Z",
     "shell.execute_reply": "2023-09-20T21:58:15.447661Z"
    },
    "id": "qQZfJBjAOE_C"
   },
   "outputs": [],
   "source": [
    "# def make_fine_tunable(model):\n",
    "#     for param in model.parameters():\n",
    "#         param.requires_grad = True\n",
    "#     for param in model.classifier.parameters():\n",
    "#         param.requires_grad = True\n",
    "#     print(\"Tunable Layers: \")\n",
    "#     for (name, param) in model.named_parameters():\n",
    "#         if param.requires_grad:\n",
    "#             print(f'{name} -> {param.requires_grad}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.449647Z",
     "iopub.status.busy": "2023-09-20T21:58:15.449478Z",
     "iopub.status.idle": "2023-09-20T21:58:15.451557Z",
     "shell.execute_reply": "2023-09-20T21:58:15.451223Z"
    },
    "id": "8OnTVWuU_QfT"
   },
   "outputs": [],
   "source": [
    "# def make_fine_tunable(model):\n",
    "#     for param in model.parameters():\n",
    "#         param.requires_grad = True\n",
    "#     for param in model.fc.parameters():\n",
    "#         param.requires_grad = True\n",
    "#     print(\"Tunable Layers: \")\n",
    "#     for (name, param) in model.named_parameters():\n",
    "#         if param.requires_grad:\n",
    "#             print(f'{name} -> {param.requires_grad}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.453086Z",
     "iopub.status.busy": "2023-09-20T21:58:15.452918Z",
     "iopub.status.idle": "2023-09-20T21:58:15.455727Z",
     "shell.execute_reply": "2023-09-20T21:58:15.455290Z"
    }
   },
   "outputs": [],
   "source": [
    "def make_fine_tunable(model):\n",
    "    \"\"\"Enable gradients for every parameter of `model`, then list the trainable ones.\"\"\"\n",
    "    for p in model.parameters():\n",
    "        p.requires_grad = True\n",
    "    print(\"Tunable Layers: \")\n",
    "    trainable = [(name, p) for name, p in model.named_parameters() if p.requires_grad]\n",
    "    for name, p in trainable:\n",
    "        print(f'{name} -> {p.requires_grad}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.457263Z",
     "iopub.status.busy": "2023-09-20T21:58:15.457103Z",
     "iopub.status.idle": "2023-09-20T21:58:15.460614Z",
     "shell.execute_reply": "2023-09-20T21:58:15.460225Z"
    },
    "id": "9wN4wCW3OFBp",
    "outputId": "53a21295-a1e1-4b6d-a968-96b07a8694f1"
   },
   "outputs": [],
   "source": [
    "make_fine_tunable(model=net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:15.462059Z",
     "iopub.status.busy": "2023-09-20T21:58:15.461923Z",
     "iopub.status.idle": "2023-09-20T21:58:17.026913Z",
     "shell.execute_reply": "2023-09-20T21:58:17.026308Z"
    },
    "id": "7xmPgaPCOFEQ"
   },
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "# Move the model to the target device *before* building the optimizer, as the\n",
    "# torch.optim docs recommend, so the optimizer holds the parameters that get trained.\n",
    "net = net.to(device)\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=1e-3, momentum=0.9, weight_decay=0.0001)\n",
    "#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.99)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:17.029186Z",
     "iopub.status.busy": "2023-09-20T21:58:17.028989Z",
     "iopub.status.idle": "2023-09-20T21:58:17.032480Z",
     "shell.execute_reply": "2023-09-20T21:58:17.032027Z"
    },
    "id": "GkFltnzqxvRL",
    "outputId": "780c90a7-1536-4efb-d23b-c5ee22336404"
   },
   "outputs": [],
   "source": [
    "def _evaluate(model, loader, loss_fn):\n",
    "    \"\"\"Evaluate `model` on `loader` without gradients, moving batches to the global `device`.\n",
    "\n",
    "    Returns (summed per-sample loss, number of correct predictions,\n",
    "    number of samples seen, predicted labels as a 1-D CPU LongTensor).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    loss_sum, n_correct, n_seen = 0.0, 0, 0\n",
    "    batch_preds = []\n",
    "    with torch.no_grad():\n",
    "        for inputs, targets in loader:\n",
    "            inputs, targets = inputs.to(device), targets.long().to(device)\n",
    "            logits = model(inputs)\n",
    "            # loss_fn returns a batch mean; re-weight by batch size so the epoch\n",
    "            # average is per sample even when the last batch is smaller.\n",
    "            loss_sum += loss_fn(logits, targets).item() * targets.size(0)\n",
    "            batch_pred = torch.argmax(logits, dim=1)\n",
    "            n_correct += (batch_pred == targets).sum().item()\n",
    "            n_seen += targets.size(0)\n",
    "            batch_preds.append(batch_pred.cpu())\n",
    "    all_preds = torch.cat(batch_preds, dim=0) if batch_preds else torch.empty(0, dtype=torch.long)\n",
    "    return loss_sum, n_correct, n_seen, all_preds\n",
    "\n",
    "\n",
    "NUM_EPOCHS = 30\n",
    "\n",
    "torch.cuda.empty_cache()\n",
    "best_val_acc = -1000  # best validation count of correct predictions so far\n",
    "best_val_model = None\n",
    "train_losses = []\n",
    "train_acc = []\n",
    "val_losses = []\n",
    "val_acc = []\n",
    "test_losses = []\n",
    "test_acc = []\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    net.train()\n",
    "    running_loss = 0.0\n",
    "    running_acc = 0\n",
    "    n_train = 0\n",
    "    for inputs, labels in train_loader:\n",
    "        inputs, labels = inputs.to(device), labels.long().to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Per-sample statistics: criterion averages over the batch, so re-weight.\n",
    "        running_loss += loss.item() * labels.size(0)\n",
    "        out = torch.argmax(outputs.detach(), dim=1)\n",
    "        assert out.shape == labels.shape\n",
    "        running_acc += (labels == out).sum().item()\n",
    "        n_train += labels.size(0)\n",
    "    training_loss = running_loss / max(n_train, 1)\n",
    "    training_accuracy = running_acc * 100 / max(n_train, 1)\n",
    "    print(f\"Train loss {epoch+1}: {training_loss},Train Acc:{training_accuracy}%\")\n",
    "    train_losses.append(training_loss)\n",
    "    train_acc.append(training_accuracy)\n",
    "\n",
    "    valid_loss, correct, n_val, val_preds = _evaluate(net, val_loader, criterion)\n",
    "    valid_loss = valid_loss / max(n_val, 1)\n",
    "    valid_accuracy = correct * 100 / max(n_val, 1)\n",
    "    val_losses.append(valid_loss)\n",
    "    val_acc.append(valid_accuracy)\n",
    "    # Keep a copy of the weights from the best validation epoch (earliest on ties).\n",
    "    if correct > best_val_acc:\n",
    "        best_val_acc = correct\n",
    "        best_val_model = deepcopy(net.state_dict())\n",
    "    print(f\"Val loss {epoch+1}: {valid_loss},Val accuracy:{valid_accuracy}%\")\n",
    "\n",
    "    # Test metrics are reported every epoch but never used for model selection.\n",
    "    test_loss, test_correct, n_test, preds = _evaluate(net, test_loader, criterion)\n",
    "    test_loss = test_loss / max(n_test, 1)\n",
    "    test_accuracy = test_correct * 100 / max(n_test, 1)\n",
    "    test_losses.append(test_loss)\n",
    "    test_acc.append(test_accuracy)\n",
    "    preds = preds.numpy()\n",
    "    print(f\"Test loss {epoch+1}: {test_loss},Test accuracy: {test_accuracy}%\")\n",
    "    print(\"\\n\")\n",
    "print('Finished Training')\n",
    "torch.save(best_val_model, \"weights_model_id_swap_attn.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-20T21:58:17.034091Z",
     "iopub.status.busy": "2023-09-20T21:58:17.033931Z",
     "iopub.status.idle": "2023-09-21T00:38:53.416755Z",
     "shell.execute_reply": "2023-09-21T00:38:53.416163Z"
    }
   },
   "outputs": [],
   "source": [
    "# # xception\n",
    "# torch.cuda.empty_cache()\n",
    "# best_val_acc = -1000\n",
    "# best_val_model = None\n",
    "# train_losses = []\n",
    "# train_acc = []\n",
    "# val_losses = []\n",
    "# val_acc = []\n",
    "# test_losses = []\n",
    "# test_acc = []\n",
    "\n",
    "# for epoch in range(60):  \n",
    "#     net.train()\n",
    "#     running_loss = 0.0\n",
    "#     running_acc = 0\n",
    "#     for i, data in enumerate(train_loader, 0):\n",
    "#         inputs, labels = data\n",
    "#         inputs, labels = inputs.cuda(),(labels.type(torch.LongTensor)).cuda()\n",
    "#         # print(\"Checking here\")\n",
    "#         optimizer.zero_grad()\n",
    "#         outputs,_,_ = net(inputs)\n",
    "#         loss = criterion(outputs, labels)\n",
    "#         # print(\"Checking if this is working\")\n",
    "#         loss.backward()\n",
    "#         optimizer.step()\n",
    "\n",
    "#         # print statistics\n",
    "#         running_loss += loss.item()\n",
    "#         out = torch.argmax(outputs.detach(),dim=1)\n",
    "#         assert out.shape==labels.shape\n",
    "#         running_acc += (labels==out).sum().item()\n",
    "#     print(f\"Train loss {epoch+1}: {running_loss/len(train_data)},Train Acc:{running_acc*100/len(train_data)}%\")\n",
    "#     training_loss = running_loss/len(train_data)\n",
    "#     training_accuracy = running_acc*100/len(train_data)\n",
    "#     train_losses.append(training_loss)\n",
    "#     train_acc.append(training_accuracy)\n",
    "    \n",
    "#     correct = 0\n",
    "#     net.eval()\n",
    "#     valid_loss = 0.0\n",
    "#     with torch.no_grad():\n",
    "#         torch.cuda.empty_cache()\n",
    "#         for inputs,labels in val_loader:\n",
    "#             out,_,_ = net(inputs.cuda())\n",
    "#             loss = criterion(out, (labels.type(torch.LongTensor)).cuda())\n",
    "#             out = torch.argmax(out.cpu(),dim=1)\n",
    "#             acc = (out==labels.cpu()).sum().item()\n",
    "#             correct += acc\n",
    "#             #loss = criterion(outputs, labels)\n",
    "#             valid_loss = valid_loss + loss.item()\n",
    "#     #print(f\"Val loss {epoch+1}: {val_loss},Val accuracy:{correct*100/len(val_data)}%\")\n",
    "#     valid_loss = valid_loss / float(len(val_data))\n",
    "#     valid_accuracy = correct*100/len(val_data)\n",
    "#     val_losses.append(valid_loss)\n",
    "#     val_acc.append(valid_accuracy)\n",
    "#     if correct>best_val_acc:\n",
    "#         best_val_acc = correct\n",
    "#         best_val_model = deepcopy(net.state_dict())\n",
    "#     #lr_scheduler.step()\n",
    "#     print(f\"Val loss {epoch+1}: {valid_loss},Val accuracy:{correct*100/len(val_data)}%\")\n",
    "\n",
    "#     test_correct = 0\n",
    "#     test_loss = 0.0\n",
    "#     preds = []\n",
    "#     #net.load_state_dict(best_val_model)\n",
    "#     net.eval()\n",
    "#     with torch.no_grad():\n",
    "#       torch.cuda.empty_cache()\n",
    "#       for i,data in enumerate(test_loader, 0):\n",
    "#           test_im, targets = data\n",
    "#           test_im, targets = test_im.cuda(), (targets.type(torch.LongTensor)).cuda()\n",
    "#           out,_,_ = net(test_im)\n",
    "#           loss = criterion(out, targets)\n",
    "#           out = torch.argmax(out,dim=1)\n",
    "#           acc = (out==targets).sum().item()\n",
    "#           preds.append(out.detach().cpu())\n",
    "#           test_correct += acc\n",
    "#           test_loss = test_loss + loss.item()\n",
    "#       test_loss = test_loss/len(test_data)\n",
    "#       test_accuracy = test_correct*100/len(test_data)\n",
    "#       test_losses.append(test_loss)\n",
    "#       test_acc.append(test_accuracy)\n",
    "\n",
    "#     preds = torch.cat(preds, dim=0).numpy()\n",
    "#     print(f\"Test loss {epoch+1}: {test_loss},Test accuracy: {test_correct*100/len(test_data)}%\")\n",
    "#     print(\"\\n\")\n",
    "#     #scheduler.step()\n",
    "#     # if (epoch + 1)%2 == 0:\n",
    "#     #   scheduler.step()\n",
    "# print('Finished Training')\n",
    "# torch.save(best_val_model, \"weights_model_id_swap_xception_reg.pth\")  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 968
    },
    "execution": {
     "iopub.execute_input": "2023-09-21T00:38:53.418810Z",
     "iopub.status.busy": "2023-09-21T00:38:53.418630Z",
     "iopub.status.idle": "2023-09-21T00:38:53.677405Z",
     "shell.execute_reply": "2023-09-21T00:38:53.676870Z"
    },
    "id": "OmqCQXmROFJv",
    "outputId": "d9fd10ae-4167-4385-b519-26c32a54d0d2"
   },
   "outputs": [],
   "source": [
    "# Training curves. `e` holds the 0-based epoch index of every recorded epoch; it is\n",
    "# derived from the history itself so it can never disagree with the number of epochs\n",
    "# actually trained (a hard-coded range(0, 60) breaks as soon as the epoch count changes).\n",
    "e = list(range(len(train_losses)))\n",
    "plt.figure(figsize=(20, 10))\n",
    "plt.subplot(1,2,1)\n",
    "plt.plot(e, train_losses, color='r', label='loss')\n",
    "plt.title('Training loss')\n",
    "plt.xlabel('epoch')\n",
    "plt.subplot(1,2,2)\n",
    "plt.plot(e, train_acc, color='b', label='acc')\n",
    "plt.title('Training accuracy')\n",
    "plt.xlabel('epoch')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 532
    },
    "execution": {
     "iopub.execute_input": "2023-09-21T00:38:53.679184Z",
     "iopub.status.busy": "2023-09-21T00:38:53.679021Z",
     "iopub.status.idle": "2023-09-21T00:38:53.914966Z",
     "shell.execute_reply": "2023-09-21T00:38:53.914477Z"
    },
    "id": "E7wAtI2lOFL2",
    "outputId": "213eb00b-2be0-49ae-be73-744c57dc9d82"
   },
   "outputs": [],
   "source": [
    "# Validation curves, indexed by 0-based epoch. The x-axis is computed here from the\n",
    "# history length so this cell does not depend on `e` from another cell or on a\n",
    "# hard-coded epoch count.\n",
    "val_epochs = list(range(len(val_losses)))\n",
    "plt.figure(figsize=(20, 10))\n",
    "plt.subplot(1,2,1)\n",
    "plt.plot(val_epochs, val_losses, color='r', label='loss')\n",
    "plt.title('Validation loss')\n",
    "plt.xlabel('epoch')\n",
    "plt.subplot(1,2,2)\n",
    "plt.plot(val_epochs, val_acc, color='b', label='acc')\n",
    "plt.title('Validation accuracy')\n",
    "plt.xlabel('epoch')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 532
    },
    "execution": {
     "iopub.execute_input": "2023-09-21T00:38:53.916725Z",
     "iopub.status.busy": "2023-09-21T00:38:53.916560Z",
     "iopub.status.idle": "2023-09-21T00:38:54.225020Z",
     "shell.execute_reply": "2023-09-21T00:38:54.224548Z"
    },
    "id": "EHyoRQZrOmpw",
    "outputId": "6c1269b6-b3dc-4137-dbed-d1a994b3c735"
   },
   "outputs": [],
   "source": [
    "# Per-epoch test curves, indexed by 0-based epoch. The x-axis is computed here from\n",
    "# the history length so this cell does not depend on `e` from another cell or on a\n",
    "# hard-coded epoch count.\n",
    "test_epochs = list(range(len(test_losses)))\n",
    "plt.figure(figsize=(20, 10))\n",
    "plt.subplot(1,2,1)\n",
    "plt.plot(test_epochs, test_losses, color='r', label='loss')\n",
    "plt.title('Test loss')\n",
    "plt.xlabel('epoch')\n",
    "plt.subplot(1,2,2)\n",
    "plt.plot(test_epochs, test_acc, color='b', label='acc')\n",
    "plt.title('Test accuracy (%)')\n",
    "plt.xlabel('epoch')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2023-09-21T00:38:54.226696Z",
     "iopub.status.busy": "2023-09-21T00:38:54.226529Z",
     "iopub.status.idle": "2023-09-21T00:38:54.229121Z",
     "shell.execute_reply": "2023-09-21T00:38:54.228687Z"
    },
    "id": "JSN9qhfmOmsm",
    "outputId": "8839c5d9-de62-479c-c050-8152ec115620"
   },
   "outputs": [],
   "source": [
    "# Final test-set evaluation using the best-validation checkpoint.\n",
    "#   preds     : confidence of the predicted class (max softmax probability) per sample\n",
    "#   pred_lbls : predicted class index (argmax of the logits) per sample\n",
    "#   lbls      : ground-truth class index per sample\n",
    "correct = 0\n",
    "conf_batches, pred_batches, label_batches = [], [], []\n",
    "net.load_state_dict(best_val_model)\n",
    "net.eval()\n",
    "with torch.no_grad():\n",
    "    for batch_images, batch_targets in test_loader:\n",
    "        batch_images = batch_images.cuda()\n",
    "        batch_targets = batch_targets.type(torch.LongTensor).cuda()\n",
    "        logits = net(batch_images)\n",
    "        batch_conf = torch.max(torch.softmax(logits, dim=1), dim=1).values\n",
    "        batch_pred = torch.argmax(logits, dim=1)\n",
    "        correct += (batch_pred == batch_targets).sum().item()\n",
    "        label_batches.append(batch_targets.detach().cpu())\n",
    "        conf_batches.append(batch_conf.detach().cpu())\n",
    "        pred_batches.append(batch_pred.detach().cpu())\n",
    "\n",
    "preds = torch.cat(conf_batches, dim=0).numpy()\n",
    "pred_lbls = torch.cat(pred_batches, dim=0).numpy()\n",
    "lbls = torch.cat(label_batches, dim=0).numpy()\n",
    "print(f\"Test accuracy: {correct*100/len(test_data)}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:38:54.230608Z",
     "iopub.status.busy": "2023-09-21T00:38:54.230442Z",
     "iopub.status.idle": "2023-09-21T00:39:27.371523Z",
     "shell.execute_reply": "2023-09-21T00:39:27.371021Z"
    }
   },
   "outputs": [],
   "source": [
    "# correct = 0\n",
    "# preds = []\n",
    "# lbls = []\n",
    "# pred_lbls = []\n",
    "# net.load_state_dict(best_val_model)\n",
    "# # net1.load_state_dict(best_val_model_net1)\n",
    "# # grnet.load_state_dict(best_val_model_grnet)\n",
    "# # /home/ashit/Downloads/best_model_all_attn_sft.pth\n",
    "# #net.load_state_dict(\"/home/ashit/Downloads/best_model_all_attn_sft.pth\")\n",
    "# net.eval()\n",
    "# # net1.eval()\n",
    "# with torch.no_grad():\n",
    "#     for i,data in enumerate(test_loader, 0):\n",
    "#         test_im, targets = data\n",
    "#         test_im, targets = test_im.cuda(), targets.type(torch.LongTensor).cuda()\n",
    "#         out1,_,_ = net(test_im)\n",
    "#         out = torch.argmax(out1,dim=1)\n",
    "#         values,_ = torch.max(torch.softmax(out1, dim=1),dim=1)\n",
    "#         acc = (out==targets).sum().item()\n",
    "#         lbls.append(targets.detach().cpu())\n",
    "#         preds.append(values.detach().cpu())\n",
    "#         pred_lbls.append(out.detach().cpu())\n",
    "#         correct += acc\n",
    "\n",
    "# preds = torch.cat(preds, dim=0).numpy()\n",
    "# pred_lbls = torch.cat(pred_lbls, dim=0).numpy()\n",
    "# lbls = torch.cat(lbls, dim=0).numpy()\n",
    "# print(f\"Test accuracy: {correct*100/len(test_data)}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.373493Z",
     "iopub.status.busy": "2023-09-21T00:39:27.373347Z",
     "iopub.status.idle": "2023-09-21T00:39:27.375699Z",
     "shell.execute_reply": "2023-09-21T00:39:27.375264Z"
    },
    "id": "J1okuH8wavwb",
    "outputId": "cc66212d-2154-414d-86bb-c818c31d4b08"
   },
   "outputs": [],
   "source": [
    "# for i, data in enumerate(train_loader, 0):\n",
    "#         inputs, labels = data\n",
    "#         print(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.377307Z",
     "iopub.status.busy": "2023-09-21T00:39:27.377157Z",
     "iopub.status.idle": "2023-09-21T00:39:27.379174Z",
     "shell.execute_reply": "2023-09-21T00:39:27.378741Z"
    },
    "id": "uCJufh2LoxSW"
   },
   "outputs": [],
   "source": [
    "# m = nn.Softmax(dim=1)\n",
    "# input = torch.randn(2, 3)\n",
    "# output = m(input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.380641Z",
     "iopub.status.busy": "2023-09-21T00:39:27.380509Z",
     "iopub.status.idle": "2023-09-21T00:39:27.382525Z",
     "shell.execute_reply": "2023-09-21T00:39:27.382200Z"
    },
    "id": "-ItqfrPvpg5d",
    "outputId": "c570ed03-8e2b-4bdf-808f-143d074c0212"
   },
   "outputs": [],
   "source": [
    "# output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.383975Z",
     "iopub.status.busy": "2023-09-21T00:39:27.383826Z",
     "iopub.status.idle": "2023-09-21T00:39:27.567088Z",
     "shell.execute_reply": "2023-09-21T00:39:27.566636Z"
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "# ROC-AUC needs a score for the positive class (index 1). `preds` holds the confidence of\n",
    "# the *predicted* class, so for samples predicted as class 0 the class-1 score is\n",
    "# 1 - preds (this assumes a two-class model). A new array is built instead of flipping\n",
    "# `preds` in place, so re-running this cell cannot flip the scores back and change the AUC.\n",
    "pos_scores = np.where(pred_lbls == 0, 1 - preds, preds)\n",
    "roc_auc_score(lbls, pos_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from torch.utils.data import ConcatDataset\n",
    "# celeb_real = CustomDataSet(\"/home/ashit/Downloads/Celebdf-realimages\",modify_transforms)\n",
    "# celeb_fake = CustomDataSet(\"/home/ashit/Downloads/Celebdf-fakeimages\",modify_transforms)\n",
    "# test_data = ConcatDataset([celeb_real,celeb_fake])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from torch.utils.data import ConcatDataset\n",
    "# dfdc_real = CustomDataSet(\"/home/ashit/Downloads/DFDC-realimages\",modify_transforms)\n",
    "# dfdc_fake = CustomDataSet(\"/home/ashit/Downloads/DFDC-fakeimages\",modify_transforms)\n",
    "# test_data = ConcatDataset([dfdc_real,dfdc_fake])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from torch.utils.data import DataLoader\n",
    "\n",
    "# image_data_loader = DataLoader(\n",
    "#   test_data, \n",
    "#   # batch size is whole datset\n",
    "#   batch_size=len(test_data), \n",
    "#   shuffle=False, \n",
    "#   num_workers=0)\n",
    "\n",
    "# def mean_std(loader):\n",
    "#   images, lebels = next(iter(loader))\n",
    "#   # shape of images = [b,c,w,h]\n",
    "#   mean, std = images.mean([0,2,3]), images.std([0,2,3])\n",
    "#   return mean, std\n",
    "\n",
    "# mean, std = mean_std(image_data_loader)\n",
    "# print(\"mean and std: \\n\", mean, std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# #celeb-db\n",
    "# modify_transforms = transforms.Compose([transforms.RandomRotation(30),\n",
    "#                                        transforms.Resize((224, 224)),\n",
    "#                                        transforms.RandomHorizontalFlip(),\n",
    "#                                        transforms.Normalize([118.1276,  78.5629,  69.1345], [61.4242, 44.1643, 41.8714])])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# modify_transforms = transforms.Compose([transforms.RandomRotation(30),\n",
    "#                                        transforms.Resize((224, 224)),\n",
    "#                                        transforms.RandomHorizontalFlip(),\n",
    "#                                        transforms.Normalize([94.5180, 76.2486, 64.1614],[55.1707, 52.8614, 47.1455])])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test_loader = torch.utils.data.DataLoader(faceswap_test, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# correct = 0\n",
    "# preds = []\n",
    "# lbls = []\n",
    "# pred_lbls = []\n",
    "# model_path = '/home/ashit/Downloads/weights_model_id_swap_attn.pth'\n",
    "# model = torch.load(model_path)\n",
    "# net.load_state_dict(model)\n",
    "# net.eval()\n",
    "# with torch.no_grad():\n",
    "#     for i,data in enumerate(test_loader, 0):\n",
    "#         test_im, targets = data\n",
    "#         test_im, targets = test_im.cuda(), targets.type(torch.LongTensor).cuda()\n",
    "#         out1 = net(test_im)\n",
    "#         values,_ = torch.max(torch.softmax(out1, dim=1),dim=1)\n",
    "#         out = torch.argmax(out1,dim=1)\n",
    "#         acc = (out==targets).sum().item()\n",
    "#         lbls.append(targets.detach().cpu())\n",
    "#         preds.append(values.detach().cpu())\n",
    "#         pred_lbls.append(out.detach().cpu())\n",
    "#         # preds.append(out.detach().cpu())\n",
    "#         correct += acc\n",
    "\n",
    "# preds = torch.cat(preds, dim=0).numpy()\n",
    "# pred_lbls = torch.cat(pred_lbls, dim=0).numpy()\n",
    "# lbls = torch.cat(lbls, dim=0).numpy()\n",
    "# print(f\"Test accuracy: {correct*100/len(deepfakes_test)}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# correct = 0\n",
    "# preds = []\n",
    "# lbls = []\n",
    "# pred_lbls = []\n",
    "# model_path = '/home/ashit/Downloads/weights_model_id_swap_xception_reg.pth'\n",
    "# model = torch.load(model_path)\n",
    "# net.load_state_dict(model)\n",
    "# net.eval()\n",
    "# # net1.eval()\n",
    "# with torch.no_grad():\n",
    "#     for i,data in enumerate(test_loader, 0):\n",
    "#         test_im, targets = data\n",
    "#         test_im, targets = test_im.cuda(), targets.type(torch.LongTensor).cuda()\n",
    "#         out1,_,_ = net(test_im)\n",
    "#         out = torch.argmax(out1,dim=1)\n",
    "#         values,_ = torch.max(torch.softmax(out1, dim=1),dim=1)\n",
    "#         acc = (out==targets).sum().item()\n",
    "#         lbls.append(targets.detach().cpu())\n",
    "#         preds.append(values.detach().cpu())\n",
    "#         pred_lbls.append(out.detach().cpu())\n",
    "#         correct += acc\n",
    "\n",
    "# preds = torch.cat(preds, dim=0).numpy()\n",
    "# pred_lbls = torch.cat(pred_lbls, dim=0).numpy()\n",
    "# lbls = torch.cat(lbls, dim=0).numpy()\n",
    "# print(f\"Test accuracy: {correct*100/len(test_data)}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from sklearn.metrics import roc_auc_score\n",
    "# # preds = np.array([0, 0, 0, 1, 0, 1, 0, 1, 1, 1])\n",
    "# # lbls = np.array([0, 0, 0, 0, 1, 0, 1, 1, 1, 1])\n",
    "# preds[pred_lbls == 0] = 1 - preds[pred_lbls == 0]\n",
    "# roc_auc_score(lbls, preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.568890Z",
     "iopub.status.busy": "2023-09-21T00:39:27.568723Z",
     "iopub.status.idle": "2023-09-21T00:39:27.571682Z",
     "shell.execute_reply": "2023-09-21T00:39:27.571363Z"
    }
   },
   "outputs": [],
   "source": [
    "# import torch\n",
    "# import torch.nn as nn\n",
    "# from torchvision import transforms\n",
    "# from PIL import Image\n",
    "\n",
    "# # Step 1: Load the pre-trained model\n",
    "# model_path = '/home/ashit/Downloads/weights_model_id_swap_attn.pth'\n",
    "# model = torch.load(model_path)\n",
    "# net.load_state_dict(model)\n",
    "# net.eval()  # Set the model to evaluation mode\n",
    "\n",
    "# # Step 2: Define image preprocessing transforms\n",
    "# # transform = transforms.Compose([\n",
    "# #     transforms.Resize((224, 224)),  # Resize image to match model input size\n",
    "# #     transforms.ToTensor(),  # Convert image to PyTorch tensor\n",
    "# #     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize image\n",
    "# # ])\n",
    "# # modify_transforms = transforms.Compose([transforms.RandomRotation(30),\n",
    "# #                                        transforms.Resize((224, 224)),\n",
    "# #                                        transforms.ToTensor(),\n",
    "# #                                        transforms.RandomHorizontalFlip(),\n",
    "# #                                        transforms.Normalize([117.1626, 104.0307, 108.8566] , [67.3590, 57.2767, 63.3579]),\n",
    "# #                                        AddGaussianNoise(0.0 , 0.1)])\n",
    "\n",
    "# modify_transforms = transforms.Compose([transforms.RandomRotation(30),\n",
    "#                                        transforms.Resize((224, 224)),\n",
    "#                                        transforms.ToTensor(),\n",
    "#                                        transforms.RandomHorizontalFlip(),\n",
    "#                                        transforms.Normalize([ 93.4020,  95.5818, 113.5519] , [55.9949, 53.7595, 64.2225]),\n",
    "#                                        AddGaussianNoise(0.0 , 0.1)])\n",
    "# # Step 3: Load and preprocess the sample image\n",
    "# image_path = '/home/ashit/Downloads/FaceSwap/manipulated_sequences/FaceSwap/c23/images/022_489_0004.png'\n",
    "# image = Image.open(image_path)\n",
    "# input_tensor = modify_transforms(image)\n",
    "# input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension\n",
    "\n",
    "# # Step 4: Perform inference\n",
    "# with torch.no_grad():\n",
    "#     output = net(input_tensor.cuda())\n",
    "\n",
    "# # Step 5: Post-process the output if needed\n",
    "# # For classification, you can use softmax to get probabilities\n",
    "# softmax = nn.Softmax(dim=1)\n",
    "# probabilities = softmax(output)\n",
    "# _, predicted_class = torch.max(probabilities, 1)\n",
    "\n",
    "# # You can now work with 'predicted_class' (the predicted class index) and 'probabilities' (class probabilities).\n",
    "# # Print the results:\n",
    "# print(\"Predicted class index:\", predicted_class.item())\n",
    "# print(\"Class probabilities:\", probabilities[0].tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Explaining model predictions with LIME (PyTorch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.573310Z",
     "iopub.status.busy": "2023-09-21T00:39:27.573145Z",
     "iopub.status.idle": "2023-09-21T00:39:27.575271Z",
     "shell.execute_reply": "2023-09-21T00:39:27.574851Z"
    }
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import os, json\n",
    "\n",
    "import torch\n",
    "from torchvision import models, transforms\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.576779Z",
     "iopub.status.busy": "2023-09-21T00:39:27.576623Z",
     "iopub.status.idle": "2023-09-21T00:39:27.578937Z",
     "shell.execute_reply": "2023-09-21T00:39:27.578448Z"
    }
   },
   "outputs": [],
   "source": [
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.580405Z",
     "iopub.status.busy": "2023-09-21T00:39:27.580252Z",
     "iopub.status.idle": "2023-09-21T00:39:27.582472Z",
     "shell.execute_reply": "2023-09-21T00:39:27.582045Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_image(path):\n",
    "    \"\"\"Read the image file at `path` and return it as an RGB PIL image.\"\"\"\n",
    "    full_path = os.path.abspath(path)\n",
    "    with open(full_path, 'rb') as handle, Image.open(handle) as raw:\n",
    "        # convert() returns a new, fully loaded image, so it stays usable after the file closes\n",
    "        rgb_image = raw.convert('RGB')\n",
    "    return rgb_image\n",
    "\n",
    "img = get_image('/home/ashit/Downloads/DeepFakeDetection/manipulated_sequences/DeepFakeDetection/c23/images/01_02__exit_phone_room__YVGY8LOK_0003.png')\n",
    "plt.imshow(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.583924Z",
     "iopub.status.busy": "2023-09-21T00:39:27.583759Z",
     "iopub.status.idle": "2023-09-21T00:39:27.586141Z",
     "shell.execute_reply": "2023-09-21T00:39:27.585698Z"
    }
   },
   "outputs": [],
   "source": [
    "# Inference-time preprocessing: resize to 256x256, center-crop to 224x224, convert to a\n",
    "# float tensor and normalise per channel.\n",
    "#\n",
    "# The Normalize statistics below are on the 0-255 pixel scale, but ToTensor() rescales\n",
    "# PIL pixels to [0, 1]. Without rescaling, every normalised pixel collapses to a\n",
    "# near-constant value (about -1.7 to -2.0) and the network sees an almost blank image,\n",
    "# so the tensor is multiplied by `pixel_scale` (255 by default) before normalising.\n",
    "# NOTE(review): confirm the training loader fed 0-255 tensors in the same channel order\n",
    "# (PIL gives RGB; a cv2-based loader would give BGR).\n",
    "def get_input_transform(pixel_scale=255.0):\n",
    "    \"\"\"Return the transform mapping a PIL image to a normalised (C, 224, 224) tensor.\n",
    "\n",
    "    pixel_scale: factor applied after ToTensor() so pixel values match the scale the\n",
    "        Normalize statistics were computed on (255.0 for 0-255 statistics).\n",
    "    \"\"\"\n",
    "    normalize = transforms.Normalize([ 91.6383,  99.6159, 130.1860] , [53.6807, 54.1669, 66.2160])\n",
    "    transf = transforms.Compose([\n",
    "        transforms.Resize((256, 256)),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Lambda(lambda t: t * pixel_scale),\n",
    "        normalize\n",
    "    ])\n",
    "\n",
    "    return transf\n",
    "\n",
    "def get_input_tensors(img):\n",
    "    \"\"\"Preprocess a single PIL image into a batch of one, shape (1, C, 224, 224).\"\"\"\n",
    "    transf = get_input_transform()\n",
    "    # unsqueeze adds the leading batch dimension\n",
    "    return transf(img).unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.587659Z",
     "iopub.status.busy": "2023-09-21T00:39:27.587501Z",
     "iopub.status.idle": "2023-09-21T00:39:27.589670Z",
     "shell.execute_reply": "2023-09-21T00:39:27.589225Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load trained weights into the existing network object. The checkpoint file is a\n",
    "# state_dict (presumably the one written by torch.save(best_val_model, ...)), so it is\n",
    "# kept in its own variable rather than overwriting `net`; rebinding `net` to a dict makes\n",
    "# this cell fail when it is run a second time (`model = net` would then be a dict).\n",
    "model_path = '/home/ashit/Downloads/weights_model_id_swap_attn.pth'\n",
    "model = net\n",
    "# map_location='cpu' lets the file load without a GPU; load_state_dict copies the\n",
    "# tensors into the model's own parameters on whatever device they live on.\n",
    "state_dict = torch.load(model_path, map_location='cpu')\n",
    "model.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.591317Z",
     "iopub.status.busy": "2023-09-21T00:39:27.591168Z",
     "iopub.status.idle": "2023-09-21T00:39:27.593169Z",
     "shell.execute_reply": "2023-09-21T00:39:27.592734Z"
    }
   },
   "outputs": [],
   "source": [
    "# All inference in this section runs on the CUDA device.\n",
    "device = 'cuda'\n",
    "model = model.to(device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.594828Z",
     "iopub.status.busy": "2023-09-21T00:39:27.594673Z",
     "iopub.status.idle": "2023-09-21T00:39:27.596716Z",
     "shell.execute_reply": "2023-09-21T00:39:27.596387Z"
    }
   },
   "outputs": [],
   "source": [
    "# idx2label, cls2label, cls2idx = [], {}, {}\n",
    "# with open(os.path.abspath('./data/imagenet_class_index.json'), 'r') as read_file:\n",
    "#     class_idx = json.load(read_file)\n",
    "#     idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]\n",
    "#     cls2label = {class_idx[str(k)][0]: class_idx[str(k)][1] for k in range(len(class_idx))}\n",
    "#     cls2idx = {class_idx[str(k)][0]: k for k in range(len(class_idx))}  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.598233Z",
     "iopub.status.busy": "2023-09-21T00:39:27.598085Z",
     "iopub.status.idle": "2023-09-21T00:39:27.600256Z",
     "shell.execute_reply": "2023-09-21T00:39:27.599832Z"
    }
   },
   "outputs": [],
   "source": [
    "# Classify the sample image with the loaded model.\n",
    "# model.eval() is called explicitly so dropout/batch-norm layers (if any) run in inference\n",
    "# mode regardless of which cells happened to run before this one.\n",
    "img_t = get_input_tensors(img)\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    output = model(img_t.to(device))\n",
    "probabilities = F.softmax(output, dim=1)\n",
    "_, predicted_class = torch.max(probabilities, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.601735Z",
     "iopub.status.busy": "2023-09-21T00:39:27.601580Z",
     "iopub.status.idle": "2023-09-21T00:39:27.603592Z",
     "shell.execute_reply": "2023-09-21T00:39:27.603166Z"
    }
   },
   "outputs": [],
   "source": [
    "# Report the predicted class index and the softmax probability of every class.\n",
    "class_index = predicted_class.item()\n",
    "class_probs = probabilities[0].tolist()\n",
    "print(\"Predicted class index:\", class_index)\n",
    "print(\"Class probabilities:\", class_probs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.605106Z",
     "iopub.status.busy": "2023-09-21T00:39:27.604947Z",
     "iopub.status.idle": "2023-09-21T00:39:27.606940Z",
     "shell.execute_reply": "2023-09-21T00:39:27.606517Z"
    }
   },
   "outputs": [],
   "source": [
    "# probs = F.softmax(logits, dim=1)\n",
    "# probs5 = probs.topk(5)\n",
    "# tuple((p,c, idx2label[c]) for p, c in zip(probs5[0][0].detach().numpy(), probs5[1][0].detach().numpy()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.608393Z",
     "iopub.status.busy": "2023-09-21T00:39:27.608263Z",
     "iopub.status.idle": "2023-09-21T00:39:27.610649Z",
     "shell.execute_reply": "2023-09-21T00:39:27.610211Z"
    }
   },
   "outputs": [],
   "source": [
    "# Preprocessing for LIME is split in two: get_pil_transform() does the geometric part\n",
    "# (resize + center crop) on the PIL image, and get_preprocess_transform() turns each\n",
    "# (possibly perturbed) image array into a normalised tensor for the model.\n",
    "def get_pil_transform():\n",
    "    \"\"\"Return the resize (256x256) + center-crop (224x224) transform for PIL images.\"\"\"\n",
    "    transf = transforms.Compose([\n",
    "        transforms.Resize((256, 256)),\n",
    "        transforms.CenterCrop(224)\n",
    "    ])\n",
    "\n",
    "    return transf\n",
    "\n",
    "def get_preprocess_transform(pixel_scale=255.0):\n",
    "    \"\"\"Return the transform mapping an HxWxC image (PIL or uint8 array) to a normalised tensor.\n",
    "\n",
    "    ToTensor() rescales uint8/PIL pixels to [0, 1], but the Normalize statistics are on\n",
    "    the 0-255 scale, so the tensor is multiplied by `pixel_scale` (255.0 by default)\n",
    "    before normalising.\n",
    "    NOTE(review): confirm the channel order matches the training loader (PIL gives RGB).\n",
    "    \"\"\"\n",
    "    normalize = transforms.Normalize([ 91.6383,  99.6159, 130.1860] , [53.6807, 54.1669, 66.2160])\n",
    "    transf = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Lambda(lambda t: t * pixel_scale),\n",
    "        normalize\n",
    "    ])\n",
    "\n",
    "    return transf\n",
    "\n",
    "pill_transf = get_pil_transform()\n",
    "preprocess_transform = get_preprocess_transform()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.612109Z",
     "iopub.status.busy": "2023-09-21T00:39:27.611956Z",
     "iopub.status.idle": "2023-09-21T00:39:27.614077Z",
     "shell.execute_reply": "2023-09-21T00:39:27.613651Z"
    }
   },
   "outputs": [],
   "source": [
    "def batch_predict(images):\n",
    "    \"\"\"Classifier function for LIME: map a batch of images to class probabilities.\n",
    "\n",
    "    images: iterable of images accepted by `preprocess_transform` (PIL images or\n",
    "        HxWxC uint8 numpy arrays).\n",
    "    Returns a numpy array of shape (len(images), num_classes) holding softmax probabilities.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    batch = torch.stack(tuple(preprocess_transform(i) for i in images), dim=0)\n",
    "\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    model.to(device)\n",
    "    batch = batch.to(device)\n",
    "\n",
    "    # Inference only: no_grad avoids building an autograd graph for every one of the\n",
    "    # perturbed batches LIME pushes through this function.\n",
    "    with torch.no_grad():\n",
    "        logits = model(batch)\n",
    "        probs = F.softmax(logits, dim=1)\n",
    "    return probs.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.615498Z",
     "iopub.status.busy": "2023-09-21T00:39:27.615370Z",
     "iopub.status.idle": "2023-09-21T00:39:27.617348Z",
     "shell.execute_reply": "2023-09-21T00:39:27.616920Z"
    }
   },
   "outputs": [],
   "source": [
    "# Sanity check: run the cropped sample through the same prediction path LIME will use.\n",
    "test_pred = batch_predict([pill_transf(img)])\n",
    "np.argmax(test_pred.squeeze())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.618786Z",
     "iopub.status.busy": "2023-09-21T00:39:27.618656Z",
     "iopub.status.idle": "2023-09-21T00:39:27.620768Z",
     "shell.execute_reply": "2023-09-21T00:39:27.620350Z"
    }
   },
   "outputs": [],
   "source": [
    "from lime import lime_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.622199Z",
     "iopub.status.busy": "2023-09-21T00:39:27.622069Z",
     "iopub.status.idle": "2023-09-21T00:39:27.624219Z",
     "shell.execute_reply": "2023-09-21T00:39:27.623784Z"
    }
   },
   "outputs": [],
   "source": [
    "# Explain the prediction for the cropped sample with LIME. Seeding the explainer makes\n",
    "# its random perturbation sampling reproducible across re-runs of the notebook.\n",
    "explainer = lime_image.LimeImageExplainer(random_state=42)\n",
    "explanation = explainer.explain_instance(np.array(pill_transf(img)),\n",
    "                                         batch_predict, # classification function\n",
    "                                         top_labels=5,\n",
    "                                         hide_color=0,\n",
    "                                         num_samples=1000) # number of perturbed images sent to the classification function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.625679Z",
     "iopub.status.busy": "2023-09-21T00:39:27.625528Z",
     "iopub.status.idle": "2023-09-21T00:39:27.627524Z",
     "shell.execute_reply": "2023-09-21T00:39:27.627100Z"
    }
   },
   "outputs": [],
   "source": [
    "from skimage.segmentation import mark_boundaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.629001Z",
     "iopub.status.busy": "2023-09-21T00:39:27.628871Z",
     "iopub.status.idle": "2023-09-21T00:39:27.630891Z",
     "shell.execute_reply": "2023-09-21T00:39:27.630467Z"
    }
   },
   "outputs": [],
   "source": [
    "# Top predicted class: outline up to 5 superpixels that positively support it, on the full image.\n",
    "temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=False)\n",
    "# Scale to [0, 1] for mark_boundaries (assumes 0-255 pixel values from the PIL-derived array)\n",
    "img_boundry1 = mark_boundaries(temp/255.0, mask)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(img_boundry1)\n",
    "ax.set_title(f'LIME: superpixels supporting class {explanation.top_labels[0]}')\n",
    "ax.axis('off');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-09-21T00:39:27.632309Z",
     "iopub.status.busy": "2023-09-21T00:39:27.632179Z",
     "iopub.status.idle": "2023-09-21T00:39:27.634185Z",
     "shell.execute_reply": "2023-09-21T00:39:27.633864Z"
    }
   },
   "outputs": [],
   "source": [
    "# Same class, 8 most influential superpixels of either sign (LIME tints supporting regions green, opposing regions red).\n",
    "temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, num_features=8, hide_rest=False)\n",
    "# Scale to [0, 1] for mark_boundaries (assumes 0-255 pixel values from the PIL-derived array)\n",
    "img_boundry2 = mark_boundaries(temp/255.0, mask)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(img_boundry2)\n",
    "ax.set_title(f'LIME: superpixels for (green) and against (red) class {explanation.top_labels[0]}')\n",
    "ax.axis('off');"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "machine_shape": "hm",
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "0495a64e954d449ebf01fbef0760deae": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_18effc414e0d44b3a4c6f57eddffb7e4",
      "max": 170498071,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_908e96e9eb2e4724b647736bbc0c5d22",
      "value": 170498071
     }
    },
    "16f4c71f484d4d6ca1dbd8a7196e402e": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "18effc414e0d44b3a4c6f57eddffb7e4": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "42dc2fb8a67d4162bdd5b5511f60930f": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "585c417cef39452389283f609169da66": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "908e96e9eb2e4724b647736bbc0c5d22": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "9dce0aa8753040f0ab3280ed88c533ce": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_a03c7fb2df3946b9807be2edf853a122",
       "IPY_MODEL_0495a64e954d449ebf01fbef0760deae",
       "IPY_MODEL_b9477599dbbd44a082c622c22f02efcd"
      ],
      "layout": "IPY_MODEL_42dc2fb8a67d4162bdd5b5511f60930f"
     }
    },
    "a03c7fb2df3946b9807be2edf853a122": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_585c417cef39452389283f609169da66",
      "placeholder": "​",
      "style": "IPY_MODEL_f9a4591273244e188943c4c224d912fe",
      "value": "100%"
     }
    },
    "b9477599dbbd44a082c622c22f02efcd": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_dc3362b414424680b73c1c2e8ac32308",
      "placeholder": "​",
      "style": "IPY_MODEL_16f4c71f484d4d6ca1dbd8a7196e402e",
      "value": " 170498071/170498071 [00:05&lt;00:00, 34420321.21it/s]"
     }
    },
    "dc3362b414424680b73c1c2e8ac32308": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "f9a4591273244e188943c4c224d912fe": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
