{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cfba65e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "# Certifying ResNet_4B needs a lot of GPU memory, so the notebook runs on the CPU by default.\n",
    "# To opt in to the GPU, use this line instead (is_available must be *called*;\n",
    "# the bare function object is always truthy and would select cuda:0 unconditionally):\n",
    "# device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "device = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "201140cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from auto_LiRPA import BoundedModule, PerturbationLpNorm, BoundedTensor\n",
    "from auto_LiRPA.utils import get_spec_matrix\n",
    "from cert_util import min_correct_with_eps, load_data, DeltaWrapper\n",
    "\n",
    "# from model_defs import resnet2b, resnet4b\n",
    "# from resnext import ResNeXt_cifar \n",
    "# from densenet_no_bn import Densenet_cifar_wobn as Densenet\n",
    "from densenet import Densenet_cifar_32 as Densenet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "79ef4ad0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-channel CIFAR-10 statistics, obtained as\n",
    "#   mean: np.mean(train_set.train_data, axis=(0,1,2))/255\n",
    "#   std:  np.std(train_set.train_data, axis=(0,1,2))/255\n",
    "cifar10_mean = (0.4914, 0.4822, 0.4465)\n",
    "cifar10_std = (0.2471, 0.2435, 0.2616)\n",
    "\n",
    "# Both tensors are stored with shape (3, 1, 1) on the selected device.\n",
    "mu, std = (\n",
    "    torch.tensor(channel_stats).view(3, 1, 1).to(device)\n",
    "    for channel_stats in (cifar10_mean, cifar10_std)\n",
    ")\n",
    "\n",
    "def normalize(X, device):\n",
    "    \"\"\"Return ``X`` moved to ``device`` and standardised as ``(X - mu) / std``.\n",
    "\n",
    "    ``mu`` and ``std`` are the notebook-level tensors of shape (3, 1, 1); both are\n",
    "    moved to ``device`` as well, so the result lives on ``device``.\n",
    "    \"\"\"\n",
    "    shift = mu.to(device)\n",
    "    scale = std.to(device)\n",
    "    return (X.to(device) - shift) / scale\n",
    "\n",
    "def bounded_results(eps, bounded_model, input_name='/1'):\n",
    "    \"\"\"Run backward bound propagation on ``bounded_model`` for a budget ``eps``.\n",
    "\n",
    "    Besides its arguments, this function reads the notebook-level variables\n",
    "    ``std``, ``delta``, ``new_image``, ``C``, ``eval_num`` and ``num_cls``; they must\n",
    "    be defined (and unchanged) before it is called.\n",
    "\n",
    "    Args:\n",
    "        eps: Perturbation radius before normalisation; it is divided by the\n",
    "            per-channel ``std`` to build the box around ``delta``.\n",
    "        bounded_model: Bounded module that is called with ``(new_image, delta)``.\n",
    "        input_name: Name of the graph node whose linear coefficients are requested.\n",
    "            Defaults to '/1'; '/input.1' is the alternative noted in the original code.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(lA, lb, lower)`` where ``lA`` is the lower-bound coefficient tensor\n",
    "        reshaped to ``(eval_num, num_cls - 1, -1)``, ``lb`` is ``lower`` minus the\n",
    "        concretized ``lA`` term, and ``lower`` is the lower bound returned by\n",
    "        ``compute_bounds``.\n",
    "    \"\"\"\n",
    "    # Normalization for the input noise: computed once and reused for both box corners.\n",
    "    scaled_eps = eps / std.view(1, 3, 1, 1).expand(1, 3, 32, 32)\n",
    "    ptb = PerturbationLpNorm(x_L=-scaled_eps, x_U=scaled_eps)\n",
    "    bounded_delta = BoundedTensor(delta, ptb)\n",
    "\n",
    "    final_name = bounded_model.final_name\n",
    "    # The upper bound is returned by compute_bounds but is not needed here.\n",
    "    lower, _upper, A_dict = bounded_model.compute_bounds(\n",
    "        x=(new_image, bounded_delta), method='backward', C=C,\n",
    "        return_A=True,\n",
    "        needed_A_dict={final_name: [input_name]},\n",
    "    )\n",
    "    lA = A_dict[final_name][input_name]['lA']\n",
    "\n",
    "    # Subtract the concretized lA term from the lower bound\n",
    "    # (presumably leaving the offset of the linear lower bound - TODO confirm).\n",
    "    lb = lower - ptb.concretize(delta, lA, sign=-1)\n",
    "\n",
    "    lA = torch.reshape(lA, (eval_num, num_cls - 1, -1))\n",
    "    return lA, lb, lower"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98316f44",
   "metadata": {},
   "source": [
    "### Reproducibility"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c037eec5",
   "metadata": {},
   "outputs": [],
   "source": [
    "my_seed = 4\n",
    "\n",
    "# Release cached GPU memory before anything else runs.\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "# Seed every random number generator the notebook imports, in a fixed order.\n",
    "for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):\n",
    "    _seed_fn(my_seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a3a24ec",
   "metadata": {},
   "source": [
    "### Hyperparameters "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f1002863",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of images loaded and certified as one batch.\n",
    "eval_num = 5\n",
    "# Number of classes of the classification task.\n",
    "num_cls = 10\n",
    "# Perturbation budget on the 0-255 scale; it is divided by 255 before use.\n",
    "adv_e = 1\n",
    "\n",
    "# Name of the model constructor to instantiate; it has to be imported in the notebook.\n",
    "# Only Densenet is imported by default (the resnet2b/resnet4b imports are commented out).\n",
    "model_name = 'Densenet'\n",
    "# adv_train_norm = 8"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abfce727",
   "metadata": {},
   "source": [
    "### Loading the model for certification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c72c77df",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DenseNet(\n",
       "  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (dense1): Sequential(\n",
       "    (0): Bottleneck(\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    )\n",
       "    (1): Bottleneck(\n",
       "      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1))\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    )\n",
       "  )\n",
       "  (trans1): Transition(\n",
       "    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))\n",
       "  )\n",
       "  (dense2): Sequential(\n",
       "    (0): Bottleneck(\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    )\n",
       "    (1): Bottleneck(\n",
       "      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1))\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    )\n",
       "    (2): Bottleneck(\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    )\n",
       "    (3): Bottleneck(\n",
       "      (bn1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1))\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    )\n",
       "  )\n",
       "  (trans2): Transition(\n",
       "    (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (conv): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))\n",
       "  )\n",
       "  (dense3): Sequential(\n",
       "    (0): Bottleneck(\n",
       "      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1))\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    )\n",
       "    (1): Bottleneck(\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    )\n",
       "    (2): Bottleneck(\n",
       "      (bn1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1))\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    )\n",
       "    (3): Bottleneck(\n",
       "      (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    )\n",
       "  )\n",
       "  (bn): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (linear1): Linear(in_features=14336, out_features=512, bias=True)\n",
       "  (linear2): Linear(in_features=512, out_features=10, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Look the constructor up by name instead of eval()-ing the string, so model_name can\n",
    "# only select something already defined in the notebook and never run arbitrary code.\n",
    "net = globals()[model_name]()\n",
    "net.load_state_dict(\n",
    "    torch.load('./model_weights/Densenet_cifar', map_location=device)['state_dict']\n",
    ")\n",
    "# Switch to inference mode right away: in training mode the BatchNorm layers normalise\n",
    "# with batch statistics and update their running statistics on every forward pass,\n",
    "# which would alter the network before it is certified.\n",
    "net.to(device).eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d980919",
   "metadata": {},
   "source": [
    "### Loading a batch of data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f2db0f5d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2d3498953c3f486bbd6ecd2986fd242e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/170498071 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/cifar-10-python.tar.gz to ./data\n"
     ]
    }
   ],
   "source": [
    "new_image, new_label = load_data(num_imgs=eval_num, random=True, dataset='CIFAR')\n",
    "new_image = normalize(new_image, device)\n",
    "# Build the specification matrix with num_cls rather than a literal 10, so it stays\n",
    "# consistent with the (eval_num, num_cls - 1, -1) reshape done in bounded_results.\n",
    "C = get_spec_matrix(new_image, new_label.long(), num_cls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ec1a3f1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4a305da7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validating the model ACC: 2\n"
     ]
    }
   ],
   "source": [
    "# Plain accuracy on the loaded batch; gradients are not needed for this check.\n",
    "with torch.no_grad():\n",
    "    predicted = torch.argmax(net(new_image), dim=1)\n",
    "num_correct = torch.sum(predicted == new_label.to(device)).item()\n",
    "print('Validating the model ACC:', num_correct/new_label.shape[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "591372fe",
   "metadata": {},
   "source": [
    "### Model concretization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ba9573e2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yizeng/anaconda3/envs/Research/lib/python3.9/site-packages/torch/nn/functional.py:2113: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  if size_prods == 1:\n",
      "/home/yizeng/anaconda3/envs/Research/lib/python3.9/site-packages/torch/onnx/symbolic_helper.py:677: UserWarning: ONNX export mode is set to inference mode, but operator batch_norm is set to training  mode. The model will be exported in inference, as specified by the export mode.\n",
      "  warnings.warn(\"ONNX export mode is set to \" + training_mode +\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "\n",
    "# Silence all Python warnings from this point on.\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Budget converted from the 0-255 scale.\n",
    "eps = adv_e / 255\n",
    "\n",
    "# Zero tensor shaped like a single image, with a leading dimension of size 1.\n",
    "delta = torch.zeros_like(new_image[0])[None]\n",
    "dummy_input = (new_image, delta)\n",
    "\n",
    "# Wrap the network so that it takes (image, delta) and make it boundable.\n",
    "model = DeltaWrapper(net.to(device))\n",
    "bounded_model = BoundedModule(model, dummy_input)\n",
    "bounded_model.eval()\n",
    "\n",
    "final_name = bounded_model.final_name"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "424ca018",
   "metadata": {},
   "source": [
    "### Results and comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "383ab5f1",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eps: 1\n",
      "Samp-wise Cert-ACC: 1%\n",
      "Set parameter Username\n",
      "Academic license - for non-commercial use only - expires 2023-05-24\n",
      "Set parameter TimeLimit to value 600\n",
      "UP-based Cert-ACC: 1.0%\n"
     ]
    }
   ],
   "source": [
    "print('Eps:', adv_e)\n",
    "\n",
    "alpha, beta, result = bounded_results(eps, bounded_model)\n",
    "# Fraction of samples whose smallest lower bound along axis 1 is strictly positive.\n",
    "samp_ACC = torch.sum(result.detach().cpu().min(axis=1)[0] > 0).numpy()/result.shape[0]\n",
    "print('Samp-wise Cert-ACC: {}%'.format(samp_ACC*100))\n",
    "\n",
    "label = new_label\n",
    "number_class = num_cls\n",
    "# Bind the second return value to its own name: assigning it to `delta` would replace\n",
    "# the zero tensor that bounded_results reads, so re-running this cell would silently\n",
    "# compute bounds around a different point.\n",
    "cert_ACC, delta_up = min_correct_with_eps(\n",
    "    alpha, beta, eps, label,\n",
    "    number_class=number_class, verbose=False, dataset='CIFAR',\n",
    ")\n",
    "print('UP-based Cert-ACC: {}%'.format(100*cert_ACC/label.shape[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d97eca7a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
