{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "# Re-import edited project modules automatically before each cell executes.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Root directory expected to hold the project modules imported below\n",
    "# (allmodels, config, utils, victim_models). Override via HARD_LABEL_MANIFOLDS_ROOT.\n",
    "PROJECT_ROOT = os.environ.get('HARD_LABEL_MANIFOLDS_ROOT', '<anonymized>/hard_label_manifolds')\n",
    "\n",
    "# Guard so re-running this cell does not append duplicate sys.path entries.\n",
    "if PROJECT_ROOT not in sys.path:\n",
    "    sys.path.append(PROJECT_ROOT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision.models as models\n",
    "from allmodels import MNIST, load_model, load_mnist_data, \\\n",
    "                      load_cifar10_data, CIFAR10, \\\n",
    "                      load_imagenet_train, load_imagenet_test\n",
    "from config import SEED\n",
    "import os, argparse\n",
    "import numpy as np\n",
    "import json\n",
    "import utils\n",
    "import shutil\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "from collections import defaultdict\n",
    "from config import dataset_to_path, dataset_to_victim_path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
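    "# Compute indices\n",
    "\n",
    "This notebook builds a class-to-sample-index database for use in experiments: for each dataset, a mapping from every class label to the indices of the samples carrying that label. The final cells additionally report the accuracy of the victim classifiers."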
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that receives the {dataset}_indices.pkl files.\n",
    "args = dict(save_dir='./')\n",
    "\n",
    "# Config handed to the dataset loaders (mirrors the keys experiment configs expect).\n",
    "state = dict(\n",
    "    seed=SEED,\n",
    "    batch_size=1,\n",
    "    test_batch_size=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_class_index(gen_dataset):\n",
    "    \"\"\"Map each integer class label to the dataset indices carrying that label.\n",
    "\n",
    "    Samples are read one at a time via indexing (no DataLoader, no batching), so\n",
    "    every sample is loaded once; expect this to be slow on large datasets.\n",
    "\n",
    "    Returns a defaultdict(list) mapping label -> ascending list of indices.\n",
    "    \"\"\"\n",
    "    ix_database = defaultdict(list)\n",
    "    for i in tqdm(range(len(gen_dataset))):\n",
    "        _, yi = gen_dataset[i]\n",
    "        # int() accepts Python ints, NumPy integer scalars and one-element tensors.\n",
    "        ix_database[int(yi)].append(i)\n",
    "    return ix_database\n",
    "\n",
    "\n",
    "def load_generator_dataset(dataset, state):\n",
    "    \"\"\"Return the (unshuffled) dataset whose labels are indexed for `dataset`.\"\"\"\n",
    "    if dataset == \"MNIST\":\n",
    "        _, _, _, gen_dataset = load_mnist_data(state, mode='generator', shuffle_test=False)\n",
    "    elif dataset == \"CIFAR10\":\n",
    "        _, _, _, gen_dataset = load_cifar10_data(state, mode='generator', shuffle_test=False)\n",
    "    elif dataset == \"Imagenet\":\n",
    "        _, gen_dataset = load_imagenet_test(state, normalize=False, shuffle_test=False)\n",
    "    else:\n",
    "        # Fail loudly instead of silently reusing a dataset from a previous iteration.\n",
    "        raise ValueError(f\"Unknown dataset: {dataset!r}\")\n",
    "    return gen_dataset\n",
    "\n",
    "\n",
    "for dataset in [\"MNIST\", \"CIFAR10\", \"Imagenet\"]:\n",
    "    state[\"dataset_path\"] = dataset_to_path[dataset]\n",
    "    gen_dataset = load_generator_dataset(dataset, state)\n",
    "\n",
    "    print(f\"Start {dataset}...\")\n",
    "    ix_database = build_class_index(gen_dataset)\n",
    "    utils.pickle_write(os.path.join(args['save_dir'], f\"{dataset}_indices.pkl\"), ix_database)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from config import robust_cifar, dataset_to_path, dataset_to_victim_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings passed to init_classifier. Rebinding `state` discards any keys set earlier.\n",
    "state = dict(\n",
    "    seed=SEED,\n",
    "    batch_size=128,\n",
    "    test_batch_size=128,\n",
    "    victim_architecture='resnet50',\n",
    "    targeted=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from victim_models.utils import init_classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each dataset in robust_cifar, load its victim and report label accuracy.\n",
    "for dataset in robust_cifar:\n",
    "    print(f\"Trying {dataset}.\")\n",
    "    state[\"dataset_path\"] = dataset_to_path[dataset]\n",
    "    state['dataset'] = dataset\n",
    "    state['victim_path'] = dataset_to_victim_path[dataset]\n",
    "\n",
    "    model_wrapper, gen_dataset, target_loader = init_classifier(state)\n",
    "    preds = []\n",
    "    actuals = []\n",
    "\n",
    "    # Target loader is model's shuffled test loader\n",
    "    for xi, yi in tqdm(target_loader):\n",
    "        if not isinstance(yi, torch.Tensor):\n",
    "            yi = torch.tensor(yi)\n",
    "\n",
    "        # Only the inputs need the GPU; labels stay on the CPU for the comparison.\n",
    "        dec = model_wrapper.predict_label(xi.cuda())\n",
    "\n",
    "        # Flatten so (B,) and (B, 1) shaped outputs both give one entry per sample.\n",
    "        preds.extend(dec.detach().cpu().reshape(-1).tolist())\n",
    "        actuals.extend(yi.reshape(-1).tolist())\n",
    "\n",
    "    # accuracy_score was never imported (NameError); compute the same percentage\n",
    "    # of exact label matches with NumPy instead.\n",
    "    acc = np.mean(np.asarray(preds) == np.asarray(actuals)) * 100\n",
    "    print(f\"Accuracy of {dataset}: {acc:.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "alae-data",
   "language": "python",
   "name": "alae-data"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
