{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8d5d56ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import yaml\n",
    "from ml_collections import ConfigDict\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, Subset\n",
    "\n",
    "from bin_cp.experiments.image_utils.cifar_resnet import ResNet\n",
    "from bin_cp.experiments.image_utils.architectures import get_architecture\n",
    "from bin_cp.experiments.image_utils.image_datasets import get_dataset\n",
    "from bin_cp.helpers.lightner import ModelManager, Output\n",
    "from bin_cp.robust.smoothing import standard_l2_norm\n",
    "from bin_cp.helpers.storage import smooth_prediction_filename"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "05677f22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- experiment configuration ---\n",
    "dataset_name = \"cifar10\"  # also used to pick the default model from general.yaml\n",
    "\n",
    "# noise levels\n",
    "model_sigma = 0.25      # selects the noise_{model_sigma} checkpoint directory\n",
    "smoothing_sigma = 0.25  # sigma passed to standard_l2_norm\n",
    "\n",
    "# sizes\n",
    "n_datapoints = 2_048    # number of leading test images kept in the Subset\n",
    "n_samples = 10_000      # NOTE(review): presumably Monte-Carlo samples per input; not used in the visible cells\n",
    "\n",
    "# flags (NOTE(review): not read by the visible cells)\n",
    "recompute = False\n",
    "save = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "33b11eda",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "dataset size = 10000\n",
      "dataset size = 2048\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Read the shared experiment config; the `with` block closes the file handle afterwards.\n",
    "with open(\"../../bin_cp/conf/general.yaml\", \"r\") as config_file:\n",
    "    general_config = yaml.safe_load(config_file)\n",
    "conf = ConfigDict(general_config[\"general\"])\n",
    "default_models = general_config[\"models\"]\n",
    "model_name = default_models[dataset_name]\n",
    "\n",
    "# Loading model\n",
    "model_file = os.path.join(conf.models_dir, dataset_name, model_name, f\"noise_{model_sigma}\", \"checkpoint.pth.tar\")\n",
    "# map_location lets a checkpoint saved on GPU be loaded when only the CPU is available.\n",
    "model_dict = torch.load(model_file, map_location=device)\n",
    "model = get_architecture(model_dict[\"arch\"], dataset_name)\n",
    "model.load_state_dict(model_dict[\"state_dict\"])\n",
    "model_obj = ModelManager(model, device=device)\n",
    "\n",
    "# Loading dataset (driven by the configured dataset_name, not a hard-coded name)\n",
    "dataset = get_dataset(dataset_name, 'test', root=conf.dataset_dir)\n",
    "print(f\"dataset size = {len(dataset)}\")\n",
    "# Keep only the first n_datapoints test images; fail here rather than mid-evaluation if too many are requested.\n",
    "assert n_datapoints <= len(dataset), f\"n_datapoints={n_datapoints} exceeds test set size {len(dataset)}\"\n",
    "subset_indices = list(range(n_datapoints))\n",
    "dataset = Subset(dataset, subset_indices)\n",
    "print(f\"dataset size = {len(dataset)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be7812ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "from attacks import PGD_L2 as PGDSmooth\n",
    "from torchattacks import PGDRS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1711dee4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ba0cd056",
   "metadata": {},
   "outputs": [],
   "source": [
    "def smoothing_function(x):\n",
    "    \"\"\"Apply standard_l2_norm to `x` with the configured smoothing_sigma.\n",
    "\n",
    "    smoothing_sigma is looked up when the function is called, so later edits to the config cell take effect.\n",
    "    NOTE(review): the shape/type expected for `x` is defined by standard_l2_norm; confirm there.\n",
    "    \"\"\"\n",
    "    return standard_l2_norm(x, sigma=smoothing_sigma)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "756f1420",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
