{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import yaml\n",
    "from ml_collections import ConfigDict\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, Subset\n",
    "\n",
    "from qrcp.experiments.image_utils.cifar_resnet import ResNet\n",
    "from qrcp.experiments.image_utils.architectures import get_architecture\n",
    "from qrcp.experiments.image_utils.image_datasets import get_dataset\n",
    "from qrcp.helpers.lightner import ModelManager, Output\n",
    "from qrcp.robust.smoothing import standard_l2_norm\n",
    "from qrcp.helpers.storage import smooth_prediction_filename\n",
    "from torchattacks import PGDL2\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name = \"cifar10\"\n",
    "model_sigma = 0.25\n",
    "n_datapoints = 2048\n",
    "smoothing_sigma = 0.25\n",
    "n_samples = 10000\n",
    "\n",
    "r = 0.12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "general_config = yaml.safe_load(open(\"../conf/general.yaml\", \"r\"))\n",
    "conf = ConfigDict(general_config[\"general\"])\n",
    "default_models = general_config[\"models\"]\n",
    "model_name = default_models[dataset_name]\n",
    "\n",
    "# Loading model\n",
    "model_file = os.path.join(conf.models_dir, dataset_name, model_name, f\"noise_{model_sigma}\", \"checkpoint.pth.tar\")\n",
    "model_dict = torch.load(model_file)\n",
    "model = get_architecture(model_dict[\"arch\"], dataset_name)\n",
    "model.load_state_dict(model_dict[\"state_dict\"])\n",
    "model_obj = ModelManager(model, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "dataset size = 10000\n",
      "dataset size = 2048\n"
     ]
    }
   ],
   "source": [
    "# Loading dataset\n",
    "dataset = get_dataset('cifar10', 'test', root=conf.dataset_dir)\n",
    "print(f\"dataset size = {len(dataset)}\")\n",
    "subset_indices = list(range(0, n_datapoints, ))\n",
    "dataset = Subset(dataset, subset_indices)\n",
    "print(f\"dataset size = {len(dataset)}\")\n",
    "\n",
    "test_dataset = DataLoader(dataset, batch_size=128, shuffle=False, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "def smooth_predict(self, test_loader, n_samples=100, smoothing_function=None):\n",
    "    if smoothing_function is None:\n",
    "        smoothing_function = lambda inputs: inputs\n",
    "    \n",
    "    self.model.eval()\n",
    "\n",
    "    y_true = []\n",
    "    y_pred = []\n",
    "    logits = []\n",
    "    for inputs, labels in tqdm(test_loader):\n",
    "        torch.cuda.empty_cache()\n",
    "        inputs = inputs.to(self.device)\n",
    "        labels = labels.to(self.device)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            batch_outputs = []\n",
    "            for iter in range(n_samples):\n",
    "                s_inputs = smoothing_function(inputs)\n",
    "                outputs = self.model(s_inputs)\n",
    "                batch_outputs.append(outputs)\n",
    "            batch_outputs = torch.stack(batch_outputs).permute(1, 0, 2)\n",
    "            logits.append(batch_outputs.cpu())\n",
    "            y_true.append(labels)\n",
    "            _, max_class = batch_outputs.max(dim=2)\n",
    "            maj_vote, _ = max_class.mode()\n",
    "            y_pred.append(maj_vote)\n",
    "    y_pred = torch.concat(y_pred)\n",
    "    y_true = torch.concat(y_true)\n",
    "    logits = torch.concat(logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "attack = PGDL2(model, eps=r, alpha=1/255, steps=20, random_start=True)\n",
    "\n",
    "for inputs, labels in test_dataset:\n",
    "    torch.cuda.empty_cache()\n",
    "    inputs = inputs.to(device)\n",
    "    labels = labels.to(device)\n",
    "    \n",
    "    adv_inputs = attack(inputs, labels)\n",
    "    break\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    adv_outputs = model(adv_inputs)\n",
    "    outputs = model(inputs)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(86, device='cuda:0')"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(adv_outputs.argmax(1) == labels).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0924, 0.0803, 0.0992, 0.1200, 0.0828, 0.1013, 0.0942, 0.0988, 0.1171,\n",
       "        0.0864, 0.1200, 0.0913, 0.1200, 0.1175, 0.1011, 0.0950, 0.0859, 0.0882,\n",
       "        0.0808, 0.0842, 0.1103, 0.1027, 0.0999, 0.0786, 0.1200, 0.1200, 0.0781,\n",
       "        0.1200, 0.1077, 0.0790, 0.0836, 0.1200, 0.1200, 0.1059, 0.0859, 0.0868,\n",
       "        0.1200, 0.1200, 0.0971, 0.1200, 0.1200, 0.0811, 0.0804, 0.0786, 0.0867,\n",
       "        0.0787, 0.1200, 0.1200, 0.0792, 0.0931, 0.0876, 0.1200, 0.0851, 0.0802,\n",
       "        0.0785, 0.0827, 0.0989, 0.0914, 0.0983, 0.1200, 0.0783, 0.1200, 0.0822,\n",
       "        0.1200, 0.0942, 0.0783, 0.0930, 0.0919, 0.1200, 0.1200, 0.0828, 0.0783,\n",
       "        0.0828, 0.1062, 0.1103, 0.1191, 0.1052, 0.0829, 0.1091, 0.1002, 0.1200,\n",
       "        0.0788, 0.0824, 0.0917, 0.0846, 0.0852, 0.0979, 0.1010, 0.0783, 0.1200,\n",
       "        0.1075, 0.1003, 0.0784, 0.0842, 0.0834, 0.1200, 0.1200, 0.0888, 0.1200,\n",
       "        0.0911, 0.0788, 0.1200, 0.0917, 0.1063, 0.0824, 0.1200, 0.0933, 0.1168,\n",
       "        0.1153, 0.1200, 0.0784, 0.0795, 0.1200, 0.1200, 0.0882, 0.1200, 0.0894,\n",
       "        0.0855, 0.1108, 0.1083, 0.1168, 0.0902, 0.1088, 0.0976, 0.0935, 0.0831,\n",
       "        0.1200, 0.0782], device='cuda:0')"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(adv_inputs - inputs).norm(p=2, dim=(1, 2, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(96, device='cuda:0')"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(outputs.argmax(1) == labels).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for data\n",
    "test_dataset"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
