{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71c6fb29",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "from torch.optim.lr_scheduler import MultiStepLR\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import utils\n",
    "\n",
    "from utils.datasets import get_dataloaders\n",
    "from models.resnet import ResNet18\n",
    "\n",
    "from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method\n",
    "from cleverhans.torch.attacks.projected_gradient_descent import projected_gradient_descent\n",
    "from easydict import EasyDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b1b5b43",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca3bfdb1",
   "metadata": {},
   "source": [
    "## Functions for robust training and eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24adda67",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pgd_l2(model, X, y, epsilon=[1., 4., 8.], num_iter=10):\n",
    "    \"\"\" Construct FGSM adversarial examples on the examples X\"\"\"\n",
    "    # randomly choose epsilon\n",
    "    epsilon = np.random.choice(epsilon)\n",
    "    \n",
    "    # choose a random starting point with length epsilon / 2\n",
    "    delta = torch.rand_like(X, requires_grad=True) \n",
    "    norm = torch.linalg.norm(delta.flatten())\n",
    "    delta.data = epsilon * delta.data / norm / 2\n",
    "    \n",
    "    alpha = 2.5 * epsilon / num_iter # fixed step size of 2.5*epsilon/100 as in https://arxiv.org/pdf/1706.06083.pdf\n",
    "    for t in range(num_iter):\n",
    "        loss = nn.CrossEntropyLoss()(model(X + delta), y)\n",
    "        loss.backward()\n",
    "        \n",
    "        # take a step\n",
    "        step = delta.grad.detach()\n",
    "        step = alpha * step / torch.linalg.norm(step.flatten())  \n",
    "        delta.data = delta.data + step\n",
    "        \n",
    "        # project on the epsilon ball around X if necessary\n",
    "        norm = torch.linalg.norm(delta.flatten())\n",
    "        if norm > epsilon:\n",
    "            delta.data = epsilon * delta.data / norm\n",
    "        \n",
    "        # next iteration\n",
    "        delta.grad.zero_()\n",
    "    return delta.detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a821fabe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cleverhans_eval_l2(net, testloader, eps):\n",
    "    # Evaluate on clean and adversarial data\n",
    "    net.eval()\n",
    "    report = EasyDict(nb_test=0, correct=0, correct_fgm=0, correct_pgd=0)\n",
    "    for idx, (x, y) in enumerate(testloader):\n",
    "        x, y = x.to(device), y.to(device)\n",
    "        x_fgm = fast_gradient_method(net, x, eps, 2)\n",
    "        x_pgd = projected_gradient_descent(net, x, eps, 2.5*eps/100, 100, 2)\n",
    "        _, y_pred = net(x).max(1)  # model prediction on clean examples\n",
    "        _, y_pred_fgm = net(x_fgm).max(\n",
    "                1\n",
    "        )  # model prediction on FGM adversarial examples\n",
    "        _, y_pred_pgd = net(x_pgd).max(\n",
    "                1\n",
    "        )  # model prediction on PGD adversarial examples\n",
    "        report.nb_test += y.size(0)\n",
    "        report.correct += y_pred.eq(y).sum().item()\n",
    "        report.correct_fgm += y_pred_fgm.eq(y).sum().item()\n",
    "        report.correct_pgd += y_pred_pgd.eq(y).sum().item()\n",
    "        if idx > 9:\n",
    "            break # 1280 samples only\n",
    "    print(\n",
    "        \"test acc on clean examples (%): {:.3f}\".format(\n",
    "            report.correct / report.nb_test * 100.0\n",
    "        )\n",
    "    )\n",
    "    print(\n",
    "        \"test acc on FGM adversarial examples (%): {:.3f}\".format(\n",
    "            report.correct_fgm / report.nb_test * 100.0\n",
    "        )\n",
    "    )\n",
    "    print(\n",
    "        \"test acc on PGD adversarial examples (%): {:.3f}\".format(\n",
    "            report.correct_pgd / report.nb_test * 100.0\n",
    "        )\n",
    "    )\n",
    "    return 1 - report.correct_pgd / report.nb_test"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7cee311",
   "metadata": {},
   "source": [
    "## Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64315f9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainloader, testloader = get_dataloaders(\"simple_word_distractor_mnist\", batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2edff33f",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ResNet18(in_channel=1)\n",
    "model.to(device)\n",
    "optimizer = optim.SGD(model.parameters(), lr=1e-1)\n",
    "scheduler = MultiStepLR(optimizer, milestones=[3], gamma=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fbbd60b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "ce_loss = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "for i_epoch in range(9): #15?    \n",
    "    model.train()\n",
    "    train_loss = 0\n",
    "    train_zero_one_loss = 0\n",
    "    for img, label in tqdm(trainloader):\n",
    "        img, label = img.to(device), label.to(device)\n",
    "        delta = pgd_l2(model, img, label) # adversarial perturbation\n",
    "        pred = model(img + delta)\n",
    "        optimizer.zero_grad()\n",
    "        loss = ce_loss(pred, label)\n",
    "        loss.backward()\n",
    "        train_loss += loss.item()  \n",
    "        train_zero_one_loss += (pred.softmax(dim=1).argmax(dim=1) != label).sum().item()\n",
    "        optimizer.step()\n",
    "    average_loss, acc = utils.datasets.test(model, testloader, device)\n",
    "    print(f'Epoch {i_epoch}. Avg. Loss: {train_loss / len(trainloader.dataset)}. Avg. Val Loss: {average_loss}. Acc.: {1-train_zero_one_loss / len(trainloader.dataset)}.  Val Acc. {acc}')\n",
    "    scheduler.step()\n",
    "    \n",
    "    torch.save(model.state_dict(), f'../saved_models/mnist_simple_word_distractor_adv_robust_l2_epoch_{i_epoch}.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc0e3e0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), f'../saved_models/mnist_simple_word_distractor_adv_robust_l2.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54864cac",
   "metadata": {},
   "source": [
    "## Eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f670331",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model.load_state_dict( torch.load(f'../saved_models/mnist_simple_word_distractor_adv_robust_l2.pth'))\n",
    "model.eval()\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fd2b7d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "for epsilon in [0.001, 0.5, 1, 2, 4, 8, 10, 20]:\n",
    "    print(f' --------- {epsilon} --------- ')\n",
    "    cleverhans_eval_l2(model, testloader, epsilon)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
