{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "awZJxbqcXR8Q",
    "outputId": "005fdebe-0867-4f99-cba3-0709755cad90"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.4.1+cu121)\n",
      "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.19.1+cu121)\n",
      "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.16.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n",
      "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.3)\n",
      "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n",
      "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n",
      "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.6.1)\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.26.4)\n",
      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (10.4.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n",
      "Collecting git+https://github.com/MadryLab/robustness.git\n",
      "  Cloning https://github.com/MadryLab/robustness.git to /tmp/pip-req-build-l1p58t6z\n",
      "  Running command git clone --filter=blob:none --quiet https://github.com/MadryLab/robustness.git /tmp/pip-req-build-l1p58t6z\n",
      "  Resolved https://github.com/MadryLab/robustness.git to commit a9541241defd9972e9334bfcdb804f6aefe24dc7\n",
      "  Running command git submodule update --init --recursive -q\n",
      "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
      "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (4.66.5)\n",
      "Requirement already satisfied: grpcio in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (1.64.1)\n",
      "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (5.9.5)\n",
      "Requirement already satisfied: gitpython in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (3.1.43)\n",
      "Requirement already satisfied: py3nvml in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (0.2.7)\n",
      "Requirement already satisfied: cox in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (0.1.post3)\n",
      "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (1.5.2)\n",
      "Requirement already satisfied: seaborn in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (0.13.1)\n",
      "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (2.4.1+cu121)\n",
      "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (0.19.1+cu121)\n",
      "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (2.1.4)\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (1.26.4)\n",
      "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (1.13.1)\n",
      "Requirement already satisfied: GPUtil in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (1.4.0)\n",
      "Requirement already satisfied: dill in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (0.3.9)\n",
      "Requirement already satisfied: tensorboardX in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (2.6.2.2)\n",
      "Requirement already satisfied: tables in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (3.8.0)\n",
      "Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from robustness==1.2.1.post2) (3.7.1)\n",
      "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from gitpython->robustness==1.2.1.post2) (4.0.11)\n",
      "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->robustness==1.2.1.post2) (1.3.0)\n",
      "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->robustness==1.2.1.post2) (0.12.1)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->robustness==1.2.1.post2) (4.53.1)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->robustness==1.2.1.post2) (1.4.7)\n",
      "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->robustness==1.2.1.post2) (24.1)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->robustness==1.2.1.post2) (10.4.0)\n",
      "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->robustness==1.2.1.post2) (3.1.4)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->robustness==1.2.1.post2) (2.8.2)\n",
      "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->robustness==1.2.1.post2) (2024.2)\n",
      "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->robustness==1.2.1.post2) (2024.1)\n",
      "Requirement already satisfied: xmltodict in /usr/local/lib/python3.10/dist-packages (from py3nvml->robustness==1.2.1.post2) (0.13.0)\n",
      "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->robustness==1.2.1.post2) (1.4.2)\n",
      "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->robustness==1.2.1.post2) (3.5.0)\n",
      "Requirement already satisfied: cython>=0.29.21 in /usr/local/lib/python3.10/dist-packages (from tables->robustness==1.2.1.post2) (3.0.11)\n",
      "Requirement already satisfied: numexpr>=2.6.2 in /usr/local/lib/python3.10/dist-packages (from tables->robustness==1.2.1.post2) (2.10.1)\n",
      "Requirement already satisfied: blosc2~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from tables->robustness==1.2.1.post2) (2.0.0)\n",
      "Requirement already satisfied: py-cpuinfo in /usr/local/lib/python3.10/dist-packages (from tables->robustness==1.2.1.post2) (9.0.0)\n",
      "Requirement already satisfied: protobuf>=3.20 in /usr/local/lib/python3.10/dist-packages (from tensorboardX->robustness==1.2.1.post2) (3.20.3)\n",
      "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->robustness==1.2.1.post2) (3.16.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch->robustness==1.2.1.post2) (4.12.2)\n",
      "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->robustness==1.2.1.post2) (1.13.3)\n",
      "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->robustness==1.2.1.post2) (3.3)\n",
      "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->robustness==1.2.1.post2) (3.1.4)\n",
      "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->robustness==1.2.1.post2) (2024.6.1)\n",
      "Requirement already satisfied: msgpack in /usr/local/lib/python3.10/dist-packages (from blosc2~=2.0.0->tables->robustness==1.2.1.post2) (1.0.8)\n",
      "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->gitpython->robustness==1.2.1.post2) (5.0.1)\n",
      "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib->robustness==1.2.1.post2) (1.16.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->robustness==1.2.1.post2) (2.1.5)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->robustness==1.2.1.post2) (1.3.0)\n"
     ]
    }
   ],
   "source": [
    "# %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "# torch/torchvision were 2.4.1+cu121 / 0.19.1+cu121 when this notebook was last run.\n",
    "%pip install torch torchvision\n",
    "# Pin robustness to the commit resolved in the recorded output so reruns get the same code.\n",
    "%pip install git+https://github.com/MadryLab/robustness.git@a9541241defd9972e9334bfcdb804f6aefe24dc7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MCq3f2uHa_s-"
   },
   "outputs": [],
   "source": [
    "# Attack hyperparameters.\n",
    "ATTACK_EPS, ATTACK_STEPSIZE, ATTACK_STEPS = 0.5, 0.1, 10\n",
    "\n",
    "# Data-loading defaults. NOTE: the next cell reassigns both NUM_WORKERS and\n",
    "# BATCH_SIZE, so these particular values are overridden before the loaders are built.\n",
    "NUM_WORKERS, BATCH_SIZE = 8, 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ab-_FXaX3hS7"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch as ch\n",
    "from robustness.datasets import CIFAR\n",
    "\n",
    "ds = CIFAR('/tmp')\n",
    "OUT_DIR = '/tmp/'\n",
    "\n",
    "# Cap DataLoader workers at the CPUs this process may use; asking for more\n",
    "# triggers torch's 'excessive worker creation' warning and can stall loading.\n",
    "_usable_cpus = len(os.sched_getaffinity(0)) if hasattr(os, 'sched_getaffinity') else (os.cpu_count() or 1)\n",
    "NUM_WORKERS = min(16, _usable_cpus)\n",
    "BATCH_SIZE = 512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "TDJ8oyhG3l9F",
    "outputId": "6140b0e0-9dff-4992-e007-d948809f4fa8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==> Preparing dataset cifar..\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:557: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 12, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
      "  warnings.warn(_create_warning_msg(\n"
     ]
    }
   ],
   "source": [
    "import torch as ch\n",
    "from robustness import datasets, defaults, model_utils, train\n",
    "from robustness.datasets import CIFAR\n",
    "\n",
    "# Logging / result storage uses cox (http://github.com/MadryLab/cox);\n",
    "# documentation: https://cox.readthedocs.io.\n",
    "import cox.store\n",
    "from cox.utils import Parameters\n",
    "\n",
    "# CIFAR ResNet-18 built without a checkpoint (no resume_path), plus its loaders.\n",
    "ds = CIFAR('/tmp/')\n",
    "m, _ = model_utils.make_and_restore_model(arch='resnet18', dataset=ds)\n",
    "train_loader, val_loader = ds.make_loaders(batch_size=BATCH_SIZE, workers=NUM_WORKERS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "28asTluD5gMH"
   },
   "outputs": [],
   "source": [
    "import torch as ch\n",
    "from robustness import datasets, model_utils\n",
    "\n",
    "# ImageNet dataset wrapper; '/tmp/' is a placeholder path (replace with your ImageNet root if you need loaders).\n",
    "ds = datasets.ImageNet('/tmp/')\n",
    "\n",
    "# Bring the adversarially trained checkpoints from Drive into the working directory.\n",
    "!cp drive/MyDrive/Results/sanity/checkpoints/*.pt .\n",
    "\n",
    "# Checkpoint to restore into a ResNet-50 (alternative: 'imagenet_l2_3_0.pt').\n",
    "cp_filename = 'imagenet_linf_8.pt'\n",
    "m, _ = model_utils.make_and_restore_model(arch='resnet50', dataset=ds, resume_path=cp_filename)\n",
    "\n",
    "# Inference mode; the module returned by eval() is displayed as the cell output.\n",
    "m.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "XkmuMTXdRxee",
    "outputId": "3050cae1-03a7-4d14-c390-0932d534291a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Probability of class 862: 0.0219\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "def preprocess_image(image_path, device):\n",
    "    \"\"\"Read `image_path` with PIL and return a normalised 1x3x224x224 batch on `device`.\n",
    "\n",
    "    Pipeline: convert to RGB, resize shorter side to 256, centre-crop 224,\n",
    "    convert to a tensor, then normalise with the ImageNet mean/std.\n",
    "    \"\"\"\n",
    "    pipeline = transforms.Compose([\n",
    "        transforms.Resize(256),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ])\n",
    "    rgb_image = Image.open(image_path).convert('RGB')\n",
    "    # unsqueeze(0) adds the leading batch dimension the model expects.\n",
    "    return pipeline(rgb_image).unsqueeze(0).to(device)\n",
    "\n",
    "def predict_class_probability(model, input_batch, class_idx):\n",
    "    \"\"\"Return the softmax probability of `class_idx` for the first image in the batch, as a float.\n",
    "\n",
    "    If the model returns a tuple, its first element is used as the logits.\n",
    "    NOTE: a later cell redefines this name with a version that returns a list\n",
    "    (one probability per image), which shadows this definition once run.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        raw = model(input_batch)\n",
    "    logits = raw[0] if isinstance(raw, tuple) else raw\n",
    "    return torch.nn.functional.softmax(logits, dim=1)[0, class_idx].item()\n",
    "\n",
    "# Example: probability of one class index for a single image.\n",
    "# NOTE: oxpets_500/ is only created by the Oxford-IIIT Pets cells further down,\n",
    "# which also delete a random subset of images; on a fresh kernel this path may not exist.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "image_path = 'oxpets_500/Abyssinian_1.jpg'  # replace with your image path\n",
    "input_batch = preprocess_image(image_path, device)\n",
    "\n",
    "# The model and its input must live on the same device.\n",
    "m = m.to(device).eval()\n",
    "\n",
    "class_idx = 862  # class index to query\n",
    "class_probability = predict_class_probability(m, input_batch, class_idx)\n",
    "print(f\"Probability of class {class_idx}: {class_probability:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "p-Dvy9-vTa8I"
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "import os\n",
    "import glob\n",
    "import numpy as np\n",
    "from collections import Counter\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import accuracy_score, classification_report\n",
    "import xgboost as xgb\n",
    "from imblearn.over_sampling import SMOTE, BorderlineSMOTE\n",
    "from sklearn.model_selection import RepeatedStratifiedKFold, GridSearchCV\n",
    "from scipy.stats import kurtosis, skew\n",
    "#from sklearn.metrics import plot_confusion_matrix\n",
    "\n",
    "import seaborn as sns\n",
    "import os\n",
    "import pickle\n",
    "import numpy as np\n",
    "import cv2\n",
    "import skimage.io\n",
    "from skimage.segmentation import quickshift, mark_boundaries\n",
    "from skimage.measure import regionprops\n",
    "import copy\n",
    "import random\n",
    "import sklearn\n",
    "import sklearn.metrics\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.linear_model import Ridge\n",
    "from skimage import filters\n",
    "import pandas as pd\n",
    "import warnings\n",
    "import pickle\n",
    "from scipy.stats import kendalltau\n",
    "import sys\n",
    "import scipy.stats as stats\n",
    "from scipy.stats import wilcoxon\n",
    "import itertools\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.linear_model import Lasso\n",
    "from functools import partial\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from sklearn.metrics import roc_curve\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import time\n",
    "from sklearn.utils import resample\n",
    "from scipy.stats import norm, gaussian_kde\n",
    "from sklearn.neighbors import KernelDensity\n",
    "import csv\n",
    "\n",
    "import matplotlib.colors as mcolors\n",
    "from skimage.transform import resize\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import f1_score, make_scorer\n",
    "\n",
    "from sklearn.metrics import roc_curve, auc\n",
    "from imblearn.pipeline import Pipeline\n",
    "from imblearn.over_sampling import BorderlineSMOTE\n",
    "import shutil\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from imblearn.pipeline import Pipeline\n",
    "from imblearn.over_sampling import BorderlineSMOTE\n",
    "from imblearn.under_sampling import EditedNearestNeighbours\n",
    "import xgboost as xgb\n",
    "from sklearn.model_selection import GridSearchCV, RepeatedStratifiedKFold\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import roc_curve, auc\n",
    "import random\n",
    "import gc\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MVWfaDrOUYOV"
   },
   "outputs": [],
   "source": [
    "def get_dataset_params(dataset_name, cp_filename='imagenet_l2_3_0.pt', pp='pixel',\n",
    "                       results_root='drive/MyDrive/Results/sanity/res_adv_linf8_seg/'):\n",
    "  \"\"\"Return (result_dir, img_dir) for `dataset_name`.\n",
    "\n",
    "  result_dir is `<results_root>/res_<dataset_name>/` and img_dir is\n",
    "  `<dataset_name>_500/`. Both end in '/' because callers build file paths by\n",
    "  plain string concatenation (e.g. result_dir + dataset + '_' + ...).\n",
    "\n",
    "  NOTE: `cp_filename` and `pp` are currently ignored; results always go under\n",
    "  `results_root`, which defaults to the linf-8 segment-perturbation directory.\n",
    "  Pass a different `results_root` for another checkpoint or perturbation mode.\n",
    "  \"\"\"\n",
    "  img_dir = dataset_name + '_500/'\n",
    "  # os.path.join adds a '/' only where one is missing, and the trailing ''\n",
    "  # component makes the result end with '/'.\n",
    "  result_dir = os.path.join(results_root, 'res_' + dataset_name, '')\n",
    "  return result_dir, img_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ROryJLnoS3fc"
   },
   "outputs": [],
   "source": [
    "# Copy the pre-sampled dataset archives from Drive and extract them.\n",
    "# -o overwrites existing files without prompting (a rerun would otherwise block\n",
    "# on unzip's interactive 'replace?' question); -q suppresses the per-file listing.\n",
    "!cp /content/drive/MyDrive/Results/sanity/datasets/*.* .\n",
    "!unzip -oq imagenette_500.zip\n",
    "!unzip -oq pvoc_500.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "mq9bmg-6THH0",
    "outputId": "4cc8c8f3-29fe-4983-f061-0b27a6e7b9d8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Oxford-IIIT Pets dataset downloaded and extracted successfully.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import tarfile\n",
    "import urllib.request\n",
    "\n",
    "def download_and_extract(url, extract_dir):\n",
    "    \"\"\"Download the .tar.gz at `url` into `extract_dir`, extract it there, then delete the archive.\"\"\"\n",
    "    os.makedirs(extract_dir, exist_ok=True)\n",
    "\n",
    "    tar_filename = os.path.join(extract_dir, \"dataset.tar.gz\")\n",
    "    urllib.request.urlretrieve(url, tar_filename)\n",
    "    try:\n",
    "        with tarfile.open(tar_filename, 'r:gz') as tar:\n",
    "            # Where available (Python 3.12+, and security backports), the 'data'\n",
    "            # filter refuses members with absolute paths or '..' components that\n",
    "            # would be written outside extract_dir.\n",
    "            if hasattr(tarfile, 'data_filter'):\n",
    "                tar.extractall(path=extract_dir, filter='data')\n",
    "            else:\n",
    "                tar.extractall(path=extract_dir)\n",
    "    finally:\n",
    "        # Remove the archive even if extraction fails.\n",
    "        os.remove(tar_filename)\n",
    "\n",
    "# Official Oxford-IIIT Pets image archive, fetched over HTTPS.\n",
    "url = \"https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz\"\n",
    "extract_dir = \"./oxpets\"\n",
    "\n",
    "# Re-use a previous extraction instead of downloading the archive again.\n",
    "if os.path.isdir(os.path.join(extract_dir, \"images\")):\n",
    "    print(\"Oxford-IIIT Pets images already present; skipping download.\")\n",
    "else:\n",
    "    download_and_extract(url, extract_dir)\n",
    "    print(\"Oxford-IIIT Pets dataset downloaded and extracted successfully.\")\n",
    "\n",
    "# Rebuild oxpets_500/ from a full copy of the images; the next cell subsamples it.\n",
    "!rm -rf oxpets_500\n",
    "!mkdir oxpets_500\n",
    "!cp oxpets/images/* oxpets_500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aq3YfECQTLGi"
   },
   "outputs": [],
   "source": [
    "# Keep a reproducible random subset of images in oxpets_500/ and delete every\n",
    "# other file there (including non-.jpg/.png files copied from the archive),\n",
    "# so later cells that list the directory only see the sampled images.\n",
    "directory = \"oxpets_500/\"\n",
    "NUM_IMAGES_TO_KEEP = 500\n",
    "SUBSET_SEED = 0  # fixed so every run keeps the same subset\n",
    "\n",
    "# Sort first: os.listdir order is arbitrary, which would defeat the seed.\n",
    "entries = sorted(os.listdir(directory))\n",
    "image_files = [f for f in entries if f.endswith(('.jpg', '.png'))]\n",
    "\n",
    "subset_rng = random.Random(SUBSET_SEED)\n",
    "files_to_keep = set(subset_rng.sample(image_files, min(NUM_IMAGES_TO_KEEP, len(image_files))))\n",
    "\n",
    "files_to_delete = [f for f in entries\n",
    "                   if f not in files_to_keep and os.path.isfile(os.path.join(directory, f))]\n",
    "for file_name in files_to_delete:\n",
    "    os.remove(os.path.join(directory, file_name))\n",
    "\n",
    "print(f\"Deleted {len(files_to_delete)} files; kept {len(files_to_keep)} images.\")\n",
    "print(\"Done!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "G_lmGpHXrY2s"
   },
   "outputs": [],
   "source": [
    "# -p creates any missing parent directories and succeeds if the directory\n",
    "# already exists, so this cell can be re-run safely.\n",
    "!mkdir -p /content/drive/MyDrive/Results/sanity/res_adv_l2e3_seg/res_imagenette\n",
    "!mkdir -p /content/drive/MyDrive/Results/sanity/res_adv_l2e3_seg/res_oxpets\n",
    "!mkdir -p /content/drive/MyDrive/Results/sanity/res_adv_l2e3_seg/res_pvoc\n",
    "!mkdir -p /content/drive/MyDrive/Results/sanity/res_adv_linf8_seg/res_imagenette\n",
    "!mkdir -p /content/drive/MyDrive/Results/sanity/res_adv_linf8_seg/res_oxpets\n",
    "!mkdir -p /content/drive/MyDrive/Results/sanity/res_adv_linf8_seg/res_pvoc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BP0qJ5nhZMrH"
   },
   "outputs": [],
   "source": [
    "def select_n_random_pixels(img_size=(224,224), n=50):\n",
    "    \"\"\"Sample `n` distinct pixel coordinates uniformly at random.\n",
    "\n",
    "    `img_size` is read as (width, height); each returned coordinate is a\n",
    "    (row, column) tuple. Draws from the module-level `random` generator.\n",
    "    \"\"\"\n",
    "    n_cols, n_rows = img_size[0], img_size[1]\n",
    "    candidates = [(row, col) for row in range(n_rows) for col in range(n_cols)]\n",
    "    return random.sample(candidates, n)\n",
    "\n",
    "def divide_into_halves(patch_size):\n",
    "    \"\"\"Split `patch_size` into (before, after) extents around a centre pixel.\n",
    "\n",
    "    `before` is patch_size // 2 and `after` takes the remainder, so for odd\n",
    "    sizes the second extent is one larger and the two always sum to patch_size.\n",
    "    \"\"\"\n",
    "    before = patch_size // 2\n",
    "    return before, patch_size - before\n",
    "\n",
    "def generate_perturbed_images(img, coords, patch_size = 9, pert_type = \"IT\"):\n",
    "    \"\"\"Return one copy of `img` per coordinate in which only that pixel is perturbed.\n",
    "\n",
    "    For each (row, col) in `coords`, a `patch_size` x `patch_size` patch around\n",
    "    the pixel (clipped at the image border) is replaced according to\n",
    "    `pert_type`; only the resulting value at (row, col) is then written into a\n",
    "    fresh copy of `img`. The rest of the patch just provides context.\n",
    "\n",
    "    pert_type codes:\n",
    "        'IT' / 'IN'          OpenCV inpainting of the patch (Telea / Navier-Stokes)\n",
    "        'U0' / 'U.5' / 'U1'  patch set to the image min / mean / max\n",
    "        'FR'                 patch set to one random colour, drawn once per call\n",
    "        'G3' / 'G9' / 'G15'  whole-image Gaussian blur, sigma 0.3 / 0.9 / 1.5\n",
    "\n",
    "    Args:\n",
    "        img: H x W x C image array (presumably uint8 RGB; inpainting casts to uint8).\n",
    "        coords: sequence of (row, col) pairs, e.g. from select_n_random_pixels.\n",
    "        patch_size: side length of the context patch.\n",
    "        pert_type: one of the codes above.\n",
    "\n",
    "    Returns:\n",
    "        List of perturbed images, one per coordinate, or None (after printing\n",
    "        \"Unknown Type\") if `pert_type` is not recognised.\n",
    "\n",
    "    Raises:\n",
    "        Re-raises any exception from cv2.inpaint after printing it, instead of\n",
    "        continuing with an undefined or stale `output`.\n",
    "    \"\"\"\n",
    "    height, width, _ = img.shape\n",
    "\n",
    "    # Drawn once, so all coordinates in an 'FR' call share the same colour.\n",
    "    random_pixel_value = [random.randint(int(np.min(img)), int(np.max(img))) for _ in range(3)]\n",
    "\n",
    "    inpaint_methods = {'IT': ('TELEA', cv2.INPAINT_TELEA), 'IN': ('NS', cv2.INPAINT_NS)}\n",
    "    blur_sigmas = {'G3': 0.3, 'G9': 0.9, 'G15': 1.5}\n",
    "    fill_types = ('U0', 'U.5', 'U1', 'FR')\n",
    "    if pert_type not in inpaint_methods and pert_type not in blur_sigmas and pert_type not in fill_types:\n",
    "        print(\"Unknown Type\")\n",
    "        return None\n",
    "\n",
    "    fill_value = None\n",
    "    blurred = None\n",
    "    if pert_type == 'U0':\n",
    "        fill_value = np.min(img)\n",
    "    elif pert_type == 'U.5':\n",
    "        fill_value = np.mean(img)\n",
    "    elif pert_type == 'U1':\n",
    "        fill_value = np.max(img)\n",
    "    elif pert_type == 'FR':\n",
    "        fill_value = random_pixel_value\n",
    "    elif pert_type in blur_sigmas:\n",
    "        # The blur does not depend on the coordinate: compute it once, not per pixel.\n",
    "        blurred = filters.gaussian(img, blur_sigmas[pert_type], channel_axis=2, preserve_range=True)\n",
    "\n",
    "    before, after = divide_into_halves(patch_size)\n",
    "    pert_images = []\n",
    "    for coord in coords:\n",
    "        row, col = int(coord[0]), int(coord[1])\n",
    "        # Half-open patch bounds, rows clipped by the height and columns by the\n",
    "        # width. The exclusive upper bound must be height/width, not height-1 /\n",
    "        # width-1, or a pixel on the last row/column lies outside its own patch\n",
    "        # and is silently left unperturbed.\n",
    "        r0, r1 = max(row - before, 0), min(row + after, height)\n",
    "        c0, c1 = max(col - before, 0), min(col + after, width)\n",
    "\n",
    "        if pert_type in inpaint_methods:\n",
    "            method_name, method_flag = inpaint_methods[pert_type]\n",
    "            mask = np.zeros((height, width), dtype=np.uint8)\n",
    "            mask[r0:r1, c0:c1] = 255\n",
    "            damaged_image = copy.deepcopy(img)\n",
    "            damaged_image[r0:r1, c0:c1] = [0, 0, 0]\n",
    "            damaged_image = damaged_image.astype(np.uint8)\n",
    "            try:\n",
    "                output = cv2.inpaint(damaged_image, mask, 5, method_flag)\n",
    "            except Exception as e:\n",
    "                print(f\"Error during {method_name} cv2.inpaint:\", e)\n",
    "                raise\n",
    "        elif blurred is not None:\n",
    "            output = blurred\n",
    "        else:\n",
    "            output = copy.deepcopy(img)\n",
    "            output[r0:r1, c0:c1] = fill_value\n",
    "\n",
    "        pert_image = copy.deepcopy(img)\n",
    "        pert_image[row, col] = output[row, col]\n",
    "        pert_images.append(pert_image)\n",
    "\n",
    "    return pert_images\n",
    "\n",
    "def preprocess_images_cv2(img_list, img_size, device):\n",
    "    \"\"\"Turn a list of numpy images into one normalised batch tensor on `device`.\n",
    "\n",
    "    Each image is converted with PIL's Image.fromarray, resized (shorter side\n",
    "    256), centre-cropped to 224, converted to a tensor and normalised with the\n",
    "    ImageNet mean/std; the per-image tensors are concatenated along dim 0.\n",
    "    `img_size` is accepted but unused. torch.cat raises if `img_list` is empty.\n",
    "    \"\"\"\n",
    "    pipeline = transforms.Compose([\n",
    "        transforms.Resize(256),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ])\n",
    "\n",
    "    # One single-image batch per input, each already moved to the target device.\n",
    "    singles = [pipeline(Image.fromarray(arr)).unsqueeze(0).to(device) for arr in img_list]\n",
    "    return torch.cat(singles)\n",
    "\n",
    "def preprocess_image_cv2(img, img_size, device):\n",
    "    \"\"\"Turn one numpy image into a normalised single-image batch on `device`.\n",
    "\n",
    "    Pipeline: PIL Image.fromarray, resize (shorter side 256), centre-crop 224,\n",
    "    to-tensor, ImageNet mean/std normalisation. `img_size` is accepted but unused.\n",
    "    \"\"\"\n",
    "    pil_image = Image.fromarray(img)\n",
    "    pipeline = transforms.Compose([\n",
    "        transforms.Resize(256),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ])\n",
    "    # unsqueeze(0) adds the leading batch dimension.\n",
    "    batch = pipeline(pil_image).unsqueeze(0)\n",
    "    return batch.to(device)\n",
    "\n",
    "def predict_class_probability(model, input_batch, class_idx):\n",
    "    \"\"\"Return the softmax probability of `class_idx` for every image in the batch.\n",
    "\n",
    "    Returns a Python list with one float per row of the model output. If the\n",
    "    model returns a tuple, its first element is used as the logits.\n",
    "\n",
    "    NOTE: this redefines the earlier single-image `predict_class_probability`\n",
    "    (which returns a float); once this cell has run, re-running the earlier\n",
    "    example cell fails at its `:.4f` formatting because it receives a list.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        result = model(input_batch)\n",
    "    logits = result[0] if isinstance(result, tuple) else result\n",
    "    return torch.nn.functional.softmax(logits, dim=1)[:, class_idx].tolist()\n",
    "\n",
    "\n",
    "def generate_blur_images_seg(img, mask, sigma):\n",
    "    \"\"\"Return a copy of `img` whose masked pixels are replaced by a Gaussian blur.\n",
    "\n",
    "    Args:\n",
    "        img: H x W x C image (presumably uint8 RGB from cv2).\n",
    "        mask: H x W array; pixels equal to 255 take the blurred value.\n",
    "        sigma: standard deviation of the Gaussian kernel.\n",
    "\n",
    "    Returns:\n",
    "        Image equal to `img` outside the mask and to the blurred image inside.\n",
    "    \"\"\"\n",
    "    # preserve_range keeps the blur in the input's own value range (0-255 for\n",
    "    # uint8) instead of skimage's default [0, 1] float scaling followed by a\n",
    "    # *255 that breaks for non-uint8 inputs. Rounding rather than truncating\n",
    "    # stops float error from shifting pixel values down by one.\n",
    "    blur_img = filters.gaussian(img, sigma, channel_axis=2, preserve_range=True)\n",
    "    blur_img = np.clip(np.rint(blur_img), 0, 255).astype(np.uint8)\n",
    "    # Broadcast the 2-D mask over the channel axis.\n",
    "    perturbed_image = np.where(np.asarray(mask)[..., np.newaxis] == 255, blur_img, img)\n",
    "\n",
    "    return perturbed_image\n",
    "\n",
    "\n",
    "def generate_perturbed_images_seg(img, segments, seg_ids, pert_type=\"IT\", max_pert_imgs=50):\n",
    "    \"\"\"Return one perturbed copy of `img` per segment id in `seg_ids`.\n",
    "\n",
    "    Every pixel whose label in `segments` equals the segment id is replaced\n",
    "    according to `pert_type`:\n",
    "        'IT' / 'IN'          OpenCV inpainting of the segment (Telea / Navier-Stokes)\n",
    "        'U0' / 'U.5' / 'U1'  segment set to (0,0,0) / (127,127,127) / (255,255,255)\n",
    "        'FR'                 segment set to a random colour, drawn per segment\n",
    "        'G3' / 'G9' / 'G15'  segment replaced by a Gaussian blur, sigma 0.3 / 0.9 / 1.5\n",
    "\n",
    "    Args:\n",
    "        img: H x W x 3 image (presumably uint8 RGB from cv2; cv2.inpaint needs 8-bit input).\n",
    "        segments: H x W label map (e.g. from quickshift).\n",
    "        seg_ids: segment labels to perturb, one output image each.\n",
    "        pert_type: one of the codes above.\n",
    "        max_pert_imgs: accepted for interface compatibility; currently unused.\n",
    "\n",
    "    Returns:\n",
    "        List of perturbed images in `seg_ids` order, or None (after printing\n",
    "        \"Unknown Type\") if `pert_type` is not recognised.\n",
    "    \"\"\"\n",
    "    inpaint_flags = {'IT': cv2.INPAINT_TELEA, 'IN': cv2.INPAINT_NS}\n",
    "    uniform_colours = {'U0': [0, 0, 0], 'U.5': [127, 127, 127], 'U1': [255, 255, 255]}\n",
    "    blur_sigmas = {'G3': 0.3, 'G9': 0.9, 'G15': 1.5}\n",
    "\n",
    "    pert_images = []\n",
    "    for segment_val in seg_ids:\n",
    "        # Evaluate the segment comparison once and reuse it below.\n",
    "        region = segments == segment_val\n",
    "        mask = np.zeros(segments.shape, dtype=np.uint8)\n",
    "        mask[region] = 255\n",
    "\n",
    "        if pert_type in inpaint_flags:\n",
    "            # The zeroed copy is only needed as inpainting input, so build it only here.\n",
    "            damaged_image = copy.deepcopy(img)\n",
    "            damaged_image[region] = 0\n",
    "            output = cv2.inpaint(damaged_image, mask, 5, inpaint_flags[pert_type])\n",
    "        elif pert_type in uniform_colours:\n",
    "            output = copy.deepcopy(img)\n",
    "            output[region] = uniform_colours[pert_type]\n",
    "        elif pert_type == 'FR':\n",
    "            output = copy.deepcopy(img)\n",
    "            output[region] = [random.randint(0, 255) for _ in range(3)]\n",
    "        elif pert_type in blur_sigmas:\n",
    "            output = generate_blur_images_seg(img, mask, sigma=blur_sigmas[pert_type])\n",
    "        else:\n",
    "            print(\"Unknown Type\")\n",
    "            return None\n",
    "\n",
    "        pert_images.append(output)\n",
    "\n",
    "    return pert_images\n",
    "\n",
    "\n",
    "def predict_top_class(model, input_batch):\n",
    "    \"\"\"Return (probability, class index) of the highest-scoring class as Python scalars.\n",
    "\n",
    "    Side effect: switches `model` to eval mode. If the model returns a tuple,\n",
    "    its first element is used as the logits. Because `.item()` is applied to\n",
    "    the per-image results, `input_batch` must contain exactly one image.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    # Inference only: no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        raw = model(input_batch)\n",
    "        logits = raw[0] if isinstance(raw, tuple) else raw\n",
    "        best_prob, best_class = F.softmax(logits, dim=1).max(dim=1)\n",
    "        return best_prob.item(), best_class.item()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6_bbaBBrsPTx"
   },
   "outputs": [],
   "source": [
    "# Perturbation codes understood by generate_perturbed_images_seg.\n",
    "pert_types = [\n",
    "    \"IT\", \"IN\",         # presumably cv2.inpaint variants -- see generate_perturbed_images_seg\n",
    "    \"U0\", \"U1\", \"U.5\",  # uniform fill: black (0), white (255), gray (127)\n",
    "    \"FR\",               # fill with a single random RGB colour\n",
    "    \"G3\", \"G9\", \"G15\",  # generate_blur_images_seg with sigma 0.3 / 0.9 / 1.5\n",
    "]\n",
    "\n",
    "# Model and dataset identifiers (dataset names are passed to get_dataset_params).\n",
    "model_names = [\"resnet50\", \"inceptionv3\", \"xception\"]\n",
    "datasets = [\"oxpets\", \"imagenette\", \"pvoc\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LiE24hTsT22Y"
   },
   "outputs": [],
   "source": [
    "# Run every perturbation type over the quickshift segments of each image and\n",
    "# cache the per-image results as a pickle. Images whose pickle already exists\n",
    "# are skipped, so the loop can be resumed after an interruption.\n",
    "# NOTE(review): `m`, `device`, `cp_filename` and the helper functions come from earlier cells.\n",
    "import zlib\n",
    "\n",
    "num_pxls = 50  # max number of segments perturbed per image\n",
    "batch_size = 10\n",
    "img_size = (224,224)\n",
    "max_images = None  # e.g. 100 to run in test mode on the first N images of each dataset\n",
    "\n",
    "pp = 'seg'\n",
    "\n",
    "# Same model for every image: move it to the device and set eval mode once.\n",
    "m = m.to(device)\n",
    "m.eval()\n",
    "\n",
    "for dataset in datasets:\n",
    "  result_dir, img_dir = get_dataset_params(dataset, cp_filename=cp_filename, pp=pp)\n",
    "  print(result_dir + \" \" + img_dir)\n",
    "\n",
    "  # os.listdir order is arbitrary; sort so runs and test-mode slices are reproducible.\n",
    "  img_filenames = sorted(os.listdir(img_dir))\n",
    "  if max_images:\n",
    "    img_filenames = img_filenames[:max_images]\n",
    "\n",
    "  for img_filename in img_filenames:\n",
    "    # splitext strips only the extension, so names such as \"a.b.jpg\" and\n",
    "    # \"a.c.jpg\" no longer map to the same pickle (and get silently skipped).\n",
    "    stem = os.path.splitext(img_filename)[0]\n",
    "    pkl_filename = result_dir + dataset + \"_\" + stem + '.pkl'\n",
    "    if os.path.exists(pkl_filename):\n",
    "      continue\n",
    "\n",
    "    image_path = os.path.join(img_dir, img_filename)\n",
    "    img = cv2.imread(image_path)\n",
    "    if img is None:\n",
    "      # cv2.imread returns None for unreadable / non-image directory entries.\n",
    "      print(f\"Skipping unreadable file: {image_path}\")\n",
    "      continue\n",
    "    img = cv2.resize(img, img_size)\n",
    "    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "    # Seed per image so the segment selection and the 'FR' colour are\n",
    "    # reproducible even when the loop is resumed part-way through.\n",
    "    img_seed = zlib.crc32(f\"{dataset}/{img_filename}\".encode())\n",
    "    np.random.seed(img_seed)\n",
    "    random.seed(img_seed)\n",
    "\n",
    "    segments = quickshift(img, kernel_size=5, max_dist=6, ratio=0.5)\n",
    "    unique_segments = np.unique(segments)\n",
    "    if len(unique_segments) > num_pxls:\n",
    "      selected_segments = np.random.choice(unique_segments, num_pxls, replace=False)\n",
    "    else:\n",
    "      selected_segments = unique_segments\n",
    "    # np.unique also sorts, so perturbed images are ordered by segment id.\n",
    "    seg_ids = np.unique(selected_segments)\n",
    "\n",
    "    input_batch = preprocess_image_cv2(img, img_size, device)\n",
    "    prob0, top_pred_class = predict_top_class(m, input_batch)\n",
    "    print(f\"Top predicted class ID: {top_pred_class}, Probability: {prob0:.4f}\")\n",
    "\n",
    "    run_dict = {}\n",
    "    for pert_type in pert_types:\n",
    "      pert_imgs = generate_perturbed_images_seg(img, segments, seg_ids=seg_ids, pert_type=pert_type)\n",
    "      pert_imgs = preprocess_images_cv2(pert_imgs, img_size, device)\n",
    "      class_probabilities = predict_class_probability(m, pert_imgs, top_pred_class)\n",
    "      print(class_probabilities)\n",
    "      run_dict[pert_type] = np.array(class_probabilities)\n",
    "      del pert_imgs, class_probabilities\n",
    "      gc.collect()\n",
    "\n",
    "    img_dict = {\n",
    "        'run_info': run_dict,\n",
    "        'prob0': prob0,\n",
    "        'img_path': dataset + \"/\" + image_path}\n",
    "\n",
    "    print(pkl_filename)\n",
    "    with open(pkl_filename, 'wb') as f1:\n",
    "      pickle.dump(img_dict, f1)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
