{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cfba65e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "# Certifying ResNet_4B needs a lot of GPU memory; consider running it on the CPU instead.\n",
    "# torch.cuda.is_available is a function and must be called: the bare function\n",
    "# object is always truthy, which would select 'cuda:0' even without a GPU.\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "201140cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from auto_LiRPA import BoundedModule, PerturbationLpNorm, BoundedTensor\n",
    "from auto_LiRPA.utils import get_spec_matrix\n",
    "from cert_util import min_correct_with_eps, load_data, DeltaWrapper\n",
    "\n",
    "from model_defs import resnet2b, resnet4b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "79ef4ad0",
   "metadata": {},
   "outputs": [],
   "source": [
    "cifar10_mean = (0.4914, 0.4822, 0.4465)  # np.mean(train_set.train_data, axis=(0,1,2))/255\n",
    "cifar10_std = (0.2471, 0.2435, 0.2616)  # np.std(train_set.train_data, axis=(0,1,2))/255\n",
    "\n",
    "mu = torch.tensor(cifar10_mean).view(3,1,1).to(device)\n",
    "std = torch.tensor(cifar10_std).view(3,1,1).to(device)\n",
    "\n",
    "def normalize(X, device):\n",
    "    \"\"\"Move `X` to `device` and standardise it with the CIFAR-10 per-channel `mu` and `std`.\"\"\"\n",
    "    result = (X.to(device) - mu.to(device))/std.to(device)\n",
    "    return result\n",
    "\n",
    "def bounded_results(eps,bounded_model):\n",
    "    \"\"\"Run backward bound propagation and return the lower-bound linear terms.\n",
    "\n",
    "    Besides its arguments, this function reads the notebook globals `std`,\n",
    "    `delta`, `new_image`, `C`, `eval_num` and `num_cls`, so the cells that\n",
    "    define them must have been executed first.\n",
    "\n",
    "    Args:\n",
    "        eps: perturbation radius before division by the per-channel std.\n",
    "        bounded_model: BoundedModule that is called with `(new_image, delta)`.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(lA, lb, lower)`: `lA` is the lower-bound coefficient tensor with\n",
    "        respect to the delta input, reshaped to `(eval_num, num_cls-1, -1)`;\n",
    "        `lb` is `lower` minus the concretized `lA` term; `lower` is the lower\n",
    "        bound returned by `compute_bounds`.\n",
    "    \"\"\"\n",
    "    # normalization for the input noise\n",
    "    eps_scaled = eps/std.view(1,3,1,1).expand(1,3,32,32)\n",
    "    ptb = PerturbationLpNorm(x_L=-eps_scaled, x_U=eps_scaled)\n",
    "    bounded_delta = BoundedTensor(delta, ptb)\n",
    "    \n",
    "    final_name = bounded_model.final_name\n",
    "    input_name = '/1' # '/input.1' \n",
    "\n",
    "    # Only the lower-bound quantities are used below; the upper bound and its\n",
    "    # coefficients were previously computed and then discarded.\n",
    "    lower, _, A_dict = bounded_model.compute_bounds(\n",
    "        x=(new_image,bounded_delta), method='backward', C=C,\n",
    "        return_A=True, \n",
    "        needed_A_dict={ final_name: [input_name] },\n",
    "    )\n",
    "    lA = A_dict[final_name][input_name]['lA']\n",
    "\n",
    "    # Subtract the concretized linear term from the lower bound.\n",
    "    lb = lower - ptb.concretize(delta, lA, sign=-1)\n",
    "\n",
    "    lA = torch.reshape(lA,(eval_num, num_cls-1,-1))\n",
    "    return lA,lb, lower"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98316f44",
   "metadata": {},
   "source": [
    "### Reproducibility"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c037eec5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the Python, NumPy and PyTorch random number generators with one shared value.\n",
    "my_seed = 0\n",
    "np.random.seed(my_seed)\n",
    "random.seed(my_seed)\n",
    "torch.manual_seed(my_seed)\n",
    "# Clear PyTorch's CUDA cache before the run.\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a3a24ec",
   "metadata": {},
   "source": [
    "### Hyperparameters "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f1002863",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of images requested from load_data; also the batch size used when reshaping the bounds.\n",
    "eval_num = 100\n",
    "# Number of classes; bounded_results reshapes the coefficients to num_cls-1 rows per image.\n",
    "num_cls = 10\n",
    "# Perturbation radius in units of 1/255 (the notebook later sets eps = adv_e/255).\n",
    "adv_e = 1\n",
    "\n",
    "# CIFAR models, one can choose from 'resnet2b', 'resnet4b'\n",
    "model_name = \"resnet2b\" \n",
    "# Suffix of the weights file: ./model_weights/<model_name>_<adv_train_norm>.pth\n",
    "adv_train_norm = 8"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abfce727",
   "metadata": {},
   "source": [
    "### Loadding the model for certification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c72c77df",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Look the constructor up explicitly instead of eval()-ing a configuration string.\n",
    "model_builders = {'resnet2b': resnet2b, 'resnet4b': resnet4b}\n",
    "if model_name not in model_builders:\n",
    "    raise ValueError('Unknown model_name {!r}; expected one of {}'.format(model_name, sorted(model_builders)))\n",
    "net = model_builders[model_name]()\n",
    "\n",
    "weights_path = './model_weights/' + model_name + '_' + str(adv_train_norm) + '.pth'\n",
    "# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "checkpoint = torch.load(weights_path, map_location=device)\n",
    "# Kept as the last expression so the key-matching report is displayed.\n",
    "net.load_state_dict(checkpoint['state_dict'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d980919",
   "metadata": {},
   "source": [
    "### Loading a batch of data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f2db0f5d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "new_image, new_label = load_data(num_imgs=eval_num, random=True, dataset='CIFAR')\n",
    "new_image = normalize(new_image,device)\n",
    "# Use the configured class count instead of a literal 10 so it stays in sync with num_cls.\n",
    "C = get_spec_matrix(new_image,new_label.long(),num_cls)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "591372fe",
   "metadata": {},
   "source": [
    "### Model concretization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ba9573e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# adv_e is given in units of 1/255.\n",
    "eps = adv_e/255\n",
    "\n",
    "# Wrap the network so that it takes an (image, delta) pair.\n",
    "model = DeltaWrapper(net.to(device))\n",
    "\n",
    "# A single all-zero delta shaped like one image with a leading batch axis.\n",
    "delta = torch.zeros_like(new_image[0]).unsqueeze(0)\n",
    "dummy_input = (new_image, delta)\n",
    "\n",
    "# Build the bound-propagation module from the wrapped network and switch it to eval mode.\n",
    "bounded_model = BoundedModule(model, dummy_input)\n",
    "bounded_model.eval()\n",
    "final_name = bounded_model.final_name"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "424ca018",
   "metadata": {},
   "source": [
    "### Results and comparision "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "383ab5f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eps: 1\n",
      "Samp-wise Cert-ACC: 39%\n",
      "Set parameter Username\n",
      "Academic license - for non-commercial use only - expires 2023-05-24\n",
      "Set parameter TimeLimit to value 600\n",
      "UP-based Cert-ACC: 41.0%\n"
     ]
    }
   ],
   "source": [
    "print('Eps:', adv_e)\n",
    "\n",
    "alpha,beta,result = bounded_results(eps,bounded_model)\n",
    "# Count the rows of `result` whose minimum over axis 1 is positive, then report\n",
    "# that count as a percentage of the batch rather than as a raw count.\n",
    "num_certified = int((result.detach().cpu().min(dim=1)[0] > 0).sum())\n",
    "samp_ACC = 100 * num_certified / len(result)\n",
    "print('Samp-wise Cert-ACC: {:g}%'.format(samp_ACC))\n",
    "\n",
    "label = new_label\n",
    "number_class = num_cls\n",
    "# Bind the second return value to its own name: rebinding the global `delta` would\n",
    "# change what bounded_results() reads the next time this cell is run.\n",
    "cert_ACC, up_delta = min_correct_with_eps(alpha, beta, eps, label, number_class=number_class, verbose=False, dataset='CIFAR')\n",
    "print('UP-based Cert-ACC: {}%'.format(cert_ACC))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d97eca7a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
