{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import edited local modules before each cell runs, so changes to the helper\n",
    "# .py files take effect without restarting the kernel. %reload_ext (unlike\n",
    "# %load_ext) is idempotent: re-running this cell does not print an\n",
    "# 'already loaded' notice.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision.datasets\n",
    "import torchvision.transforms as transforms\n",
    "import numpy as np\n",
    "import os\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "from pgd_attack import forward, perturb\n",
    "from sklearn.metrics import roc_curve, auc as auc_fn\n",
    "# NOTE(review): `generate` (used in the generation cell) presumably comes from one of\n",
    "# the star imports below -- prefer importing it explicitly. Later imports can shadow\n",
    "# names the star imports export (and vice versa), so keep this order unless the\n",
    "# `dataset` / `eval_utils` modules are audited.\n",
    "from dataset import *\n",
    "from utils import set_eval, set_train\n",
    "import seaborn as sns\n",
    "import torchvision\n",
    "from torchvision.utils import make_grid\n",
    "import pandas as pd\n",
    "from eval_utils import *\n",
    "import torchvision.models as models\n",
    "import pathlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda:0'\n",
    "datadir = './datasets'\n",
    "\n",
    "# Architecture only: the generation cell loads a checkpoint with the default (strict)\n",
    "# load_state_dict, which overwrites every parameter and buffer, so fetching the\n",
    "# ImageNet weights (pretrained=True) was a wasted download that could fail offline.\n",
    "model = models.resnet50()\n",
    "# NOTE(review): wrapped in DataParallel presumably because the checkpoint was saved\n",
    "# from a DataParallel model (state-dict keys prefixed 'module.') -- confirm.\n",
    "model = nn.DataParallel(model)\n",
    "model = model.to(device)\n",
    "\n",
    "model_dir = './experiments/celebahq256/'\n",
    "generation_result_dir = 'results/generation/celebahq256'\n",
    "pathlib.Path(generation_result_dir).mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Using downloaded and verified file: ./datasets/test_32x32.mat\n",
      "Files already downloaded and verified\n",
      "{'norm': 'L2', 'eps': 100, 'steps': 400, 'step_size': 1.2}\n"
     ]
    }
   ],
   "source": [
    "# Generation: load the checkpoint saved at step K, then run `generate` on CelebAHQ.\n",
    "K = 71  # checkpoint step; selects the directory 'steps00071'\n",
    "checkpoint_path = os.path.join(model_dir, 'steps{:05d}'.format(K), 'model.pth')\n",
    "# map_location places the loaded tensors on `device`, so a checkpoint saved from a\n",
    "# different GPU (or CPU) still loads.\n",
    "model.load_state_dict(torch.load(checkpoint_path, map_location=device))\n",
    "set_eval(model)\n",
    "img0, img1 = generate(model, 'CelebAHQ', ood_dataset_name='precomputed', resolution=256)\n",
    "\n",
    "fig, (ax0, ax1) = plt.subplots(ncols=2, dpi=150)\n",
    "for ax, img in ((ax0, img0), (ax1, img1)):\n",
    "    ax.set_axis_off()\n",
    "    ax.imshow(img)\n",
    "\n",
    "# Only the second image is written to disk; assumes `img1` is a PIL image (has .save).\n",
    "filename = os.path.join(generation_result_dir, 'celeba256_out.png')\n",
    "img1.save(filename)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
