{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from utils.data import get_dataloaders, AdversarialDataset\n",
    "from utils.models import get_model, Mask, MaskedClf\n",
    "from torch.utils.data import DataLoader\n",
    "from utils.data import AdversarialDataset\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import foolbox"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "base_model=get_model('resnet20')\n",
    "base_model = base_model.to(device)\n",
    "base_model.load_state_dict(torch.load(\"trained_models/resnet20/clean.pt\"))\n",
    "base_model.eval()\n",
    "fmodel = foolbox.models.PyTorchModel(base_model, bounds=(0,1))\n",
    "if not os.path.exists('class_specific'):\n",
    "    os.makedirs('class_specific', exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss=torch.nn.CrossEntropyLoss()\n",
    "dataloaders=get_dataloaders('cifar10', 128, 1, shuffle_train=True, shuffle_test=False, unnorm=True)\n",
    "dataset=AdversarialDataset(fmodel, 'resnet20', 'FMN', dataloaders['train'], 32, 'train')\n",
    "dataset.clean_imgs=dataset.clean_imgs[dataset.labels.argsort()]\n",
    "dataset.labels=dataset.labels[dataset.labels.argsort()]\n",
    "train_dataloader= DataLoader(dataset, batch_size=5000, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "for x,xadv,y in train_dataloader:\n",
    "    print(\"Class: \", y.unique()[0].item())\n",
    "    losses=[]\n",
    "    x=x.to(device)\n",
    "    y=y.to(device)\n",
    "    correct=(base_model(x).argmax(-1)==y)\n",
    "    x=x\n",
    "    x=x[correct]\n",
    "    y=y[correct]\n",
    "    model=MaskedClf(Mask((3, 32, 32)).to(device), base_model)\n",
    "    for p in model.clf.parameters():\n",
    "        p.requires_grad=False\n",
    "    model.mask.train()\n",
    "    optimizer=torch.optim.Adam(model.mask.parameters(), lr=0.01)\n",
    "    for e in range(5000):\n",
    "        print(e, end='\\r')\n",
    "        out=model(x)\n",
    "        l=loss(out, y)\n",
    "        penalty=model.mask.M.abs().sum()\n",
    "        l+=penalty*0.01\n",
    "        losses.append(l.item())\n",
    "        optimizer.zero_grad()\n",
    "        l.backward()\n",
    "        optimizer.step()\n",
    "        model.mask.M.data.clamp_(0., 1.)\n",
    "        c=y[0].cpu().item()\n",
    "        if(e>500 and abs(l.item()-np.mean(losses[-20:]))<1e-5):\n",
    "            print((model(x).argmax(-1)==y).sum(), e)\n",
    "            mask=torch.fft.fftshift(model.mask.M.detach().cpu())\n",
    "            mask=mask.squeeze().numpy()\n",
    "            np.save(f'class_specific/{c}.npy', mask)\n",
    "            plt.figure()\n",
    "            plt.imshow(mask[0], cmap=\"Blues\")\n",
    "            plt.colorbar()\n",
    "            plt.savefig(f'class_specific/{c}R.png')\n",
    "            plt.close()\n",
    "            plt.figure()\n",
    "            plt.imshow(mask[1], cmap=\"Blues\")\n",
    "            plt.colorbar()\n",
    "            plt.savefig(f'class_specific/{c}G.png')\n",
    "            plt.close()\n",
    "            plt.figure()\n",
    "            plt.imshow(mask[2], cmap=\"Blues\")\n",
    "            plt.colorbar()\n",
    "            plt.savefig(f'class_specific/{c}B.png')\n",
    "            plt.close()\n",
    "            break\n",
    "    del model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adv_test_dataloader=DataLoader(AdversarialDataset(fmodel, 'resnet20', 'FMN', dataloaders['test'], 32, 'test'), batch_size=1000, shuffle=False)\n",
    "correct=0\n",
    "adversarial=0\n",
    "masked=0\n",
    "for x,xadv,y in adv_test_dataloader:\n",
    "    x=x.to(device)\n",
    "    xadv=xadv.to(device)\n",
    "    y=y.to(device)\n",
    "    c=y[0].cpu().item()\n",
    "    clean_out=base_model(x)\n",
    "    correct_images=(clean_out.argmax(-1)==y)\n",
    "    correct+=correct_images.sum()\n",
    "    adv_out=base_model(xadv[correct_images])\n",
    "    adv_images=(adv_out.argmax(-1)!=y[correct_images])\n",
    "    adversarial+=adv_images.sum()\n",
    "    masked_model=MaskedClf(Mask((3, 32, 32)).to(device), base_model)\n",
    "    masked_model.mask.M.data=torch.fft.ifftshift(torch.tensor(np.load(f'class_specific/{c}.npy')))\n",
    "    masked_model.mask=masked_model.mask.to(device)\n",
    "    xadv=xadv[correct_images]\n",
    "    masked_out=masked_model(xadv)\n",
    "    masked_images=(masked_out.argmax(-1)!=y[correct_images])\n",
    "    masked+=masked_images.sum()\n",
    "print(\"Correctly classified: \", correct.item(), \"Adversarial: \", adversarial.item(), \"Adversarial after using the mask:\", masked.item())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
