{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Relative path two directory levels up. Other cells join it with\n",
    "# 'artificial' (saved results that are loaded) and 'figs' (PDFs that are written).\n",
    "root = os.path.join('..', '..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Add `root` to the import search path -- presumably so that the `utils`\n",
    "# package imported in the next cell resolves (TODO confirm the repo layout).\n",
    "sys.path.append(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import dataclasses\n",
    "from abc import ABC, abstractmethod\n",
    "from typing import Any, Dict, List, Literal, Optional, Tuple\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import patchworklib as pw\n",
    "import torch\n",
    "from torch import Tensor\n",
    "\n",
    "from utils.classifiers import OneHiddenNet\n",
    "from utils.decision_map import (get_axis_vec, get_decision_map,\n",
    "                                get_inputs_for_decision_map)\n",
    "from utils.fig import Axes, Figure\n",
    "from utils.utils import freeze, gpu"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Global Settings & Variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure-wide rendering options provided by utils.fig.Figure.\n",
    "Figure.set_tex()\n",
    "Figure.set_high_dpi()\n",
    "device = gpu(0)  # device the classifier and the decision-map inputs are moved to\n",
    "resolution = 500  # side length of every (resolution x resolution) decision map\n",
    "limit = 3.4  # forwarded to get_inputs_for_decision_map; presumably the plotted axis range -- TODO confirm\n",
    "ylabels = ('Standard', 'Adversarial', 'Noise')  # row labels passed to the left-hand blocks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Font scale in effect for the four-block figures generated by the all_blocks cells.\n",
    "Figure.set_font_scale(1.4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclasses.dataclass\n",
    "class DataUtil:\n",
    "    \"\"\"Saved results of one experiment, plus helpers to render decision maps from them.\n",
    "\n",
    "    The eight fields identify the experiment: joined with underscores, in\n",
    "    declaration order, they form the name of the file read from\n",
    "    ``<root>/artificial``. The loaded dictionary is kept in ``self.d``.\n",
    "    \"\"\"\n",
    "    in_dim: int\n",
    "    hidden_dim: int\n",
    "    n_sample: int\n",
    "    n_noise_sample: int\n",
    "    norm: Literal['L0', 'L2', 'Linf']\n",
    "    mode: Literal['uniform', 'gauss']\n",
    "    perturbation_constraint: float\n",
    "    seed: int\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        # Read the stored results once, at construction time.\n",
    "        self.d = self._load_data()\n",
    "\n",
    "    def _load_data(self) -> Dict[str, Any]:\n",
    "        \"\"\"Load the result file named after this instance's fields, mapped onto the CPU.\"\"\"\n",
    "        name_parts = (\n",
    "            self.in_dim, self.hidden_dim, self.n_sample, self.n_noise_sample,\n",
    "            self.norm, self.mode, self.perturbation_constraint, self.seed,\n",
    "        )\n",
    "        fname = '_'.join(str(part) for part in name_parts)\n",
    "        # NOTE(review): torch.load unpickles the file; only use with trusted result files.\n",
    "        return torch.load(os.path.join(root, 'artificial', fname), map_location='cpu')\n",
    "\n",
    "    def _define_classifier(self) -> OneHiddenNet:\n",
    "        \"\"\"Build a ``OneHiddenNet`` on the global ``device``, pass it through ``freeze`` and set eval mode.\"\"\"\n",
    "        net = OneHiddenNet(self.in_dim, self.hidden_dim)\n",
    "        net.to(device)\n",
    "        freeze(net)\n",
    "        net.eval()\n",
    "        return net\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def get_decision_maps_and_acc_list(self) -> Tuple[Tensor, Tensor]:\n",
    "        \"\"\"Compute one decision map per stored classifier and collect the stored accuracies.\n",
    "\n",
    "        Returns:\n",
    "            ``(decision_maps, acc_list)`` with shapes ``(3, resolution, resolution)``\n",
    "            and ``(3,)``. Row order follows the ``'classifier'``,\n",
    "            ``'adv_classifier'`` and ``'noise_classifier'`` entries of ``self.d``.\n",
    "        \"\"\"\n",
    "        # The plotting plane is derived from the first-layer weight of the 'classifier' entry.\n",
    "        axis_vec_1, axis_vec_2 = get_axis_vec(self.d['classifier']['linear.weight'])\n",
    "        grid = get_inputs_for_decision_map(axis_vec_1, axis_vec_2, resolution, limit).to(device)\n",
    "\n",
    "        # A single network instance is reused; only its weights are swapped per row.\n",
    "        net = self._define_classifier()\n",
    "\n",
    "        # (state-dict key, accuracy key) for each output row.\n",
    "        entries = (\n",
    "            ('classifier', 'acc'),\n",
    "            ('adv_classifier', 'adv_acc_for_natural'),\n",
    "            ('noise_classifier', 'noise_acc_for_natural'),\n",
    "        )\n",
    "\n",
    "        decision_maps = torch.empty(3, resolution, resolution)\n",
    "        acc_list = torch.empty(3)\n",
    "        for row, (weight_key, acc_key) in enumerate(entries):\n",
    "            net.load_state_dict(self.d[weight_key])\n",
    "            decision_maps[row] = get_decision_map(net, grid)\n",
    "            acc_list[row] = self.d[acc_key]\n",
    "\n",
    "        return decision_maps, acc_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Figure Utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Superclass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Block(ABC):\n",
    "    \"\"\"Base class for one panel of decision maps.\n",
    "\n",
    "    Rows are the three classifiers returned by\n",
    "    ``DataUtil.get_decision_maps_and_acc_list``; columns are the entries of\n",
    "    ``variables``. Subclasses are dataclasses that set ``variables`` in\n",
    "    ``__post_init__`` and implement ``_define_artificial_instance``.\n",
    "    Calling an instance builds and returns the assembled ``pw.Bricks``.\n",
    "    \"\"\"\n",
    "    suptitle: str\n",
    "    ylabels: Optional[Tuple[str, str, str]]\n",
    "    variables: Any\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        # By default the swept values double as the column titles.\n",
    "        self.titles = self.variables\n",
    "\n",
    "    def _embed_decision_map_into_brick(\n",
    "        self,\n",
    "        brick: pw.Brick, \n",
    "        decision_map: Tensor, \n",
    "        acc: float, \n",
    "        title: Optional[float] = None, \n",
    "        ylabel: Optional[str] = None,\n",
    "    ) -> None:\n",
    "        \"\"\"Draw one decision map into ``brick`` and label it.\n",
    "\n",
    "        Args:\n",
    "            brick: Target brick, wrapped in ``Axes`` for drawing.\n",
    "            decision_map: Map passed to ``Axes.imshow``.\n",
    "            acc: Accuracy as a fraction; shown as a whole-number percentage on the x-axis.\n",
    "            title: Optional column title, formatted with thousands separators.\n",
    "            ylabel: Optional row label.\n",
    "        \"\"\"\n",
    "        ax = Axes(brick)\n",
    "        ax.imshow(decision_map, True)\n",
    "        # Round instead of truncating: ``acc`` comes out of a float32 tensor, so a\n",
    "        # value such as 0.57 arrives as 0.5699999... and int() would display 56%.\n",
    "        ax.set_xlabel(f'{round(100 * acc)}' + r'\\%')\n",
    "        if title is not None:\n",
    "            ax.set_title(f'{title:,}')\n",
    "        if ylabel is not None:\n",
    "            ax.set_ylabel(ylabel)\n",
    "\n",
    "    def _construct_block(\n",
    "        self,\n",
    "        decision_map_2dlist: Tensor,\n",
    "        acc_2dlist: Tensor, \n",
    "        suptitle: str,\n",
    "        titles: Tuple[float, ...], \n",
    "        ylabels: Optional[Tuple[str, str, str]],\n",
    "    ) -> pw.Bricks:\n",
    "        \"\"\"Assemble a grid of bricks from per-cell decision maps and accuracies.\n",
    "\n",
    "        Args:\n",
    "            decision_map_2dlist: Decision maps indexed as ``[row][column]``.\n",
    "            acc_2dlist: Accuracies indexed the same way.\n",
    "            suptitle: Title set on the assembled block.\n",
    "            titles: One title per column; only drawn on the first row.\n",
    "            ylabels: One label per row, only drawn on the first column; ``None`` disables row labels.\n",
    "\n",
    "        Returns:\n",
    "            The stacked ``pw.Bricks`` with ``suptitle`` applied.\n",
    "        \"\"\"\n",
    "        row_bricks_list: List[pw.Bricks] = []\n",
    "        for i, (decision_map_list, acc_list) in enumerate(zip(decision_map_2dlist, acc_2dlist)):\n",
    "\n",
    "            col_brick_list: List[pw.Brick] = []\n",
    "            for j, (decision_map, acc) in enumerate(zip(decision_map_list, acc_list)):\n",
    "\n",
    "                brick = pw.Brick()\n",
    "                col_brick_list.append(brick)\n",
    "\n",
    "                # Titles only on the top row, row labels only on the left column.\n",
    "                title = titles[j] if i == 0 else None\n",
    "                ylabel = ylabels[i] if j == 0 and ylabels is not None else None\n",
    "\n",
    "                self._embed_decision_map_into_brick(brick, decision_map, acc.item(), title, ylabel)\n",
    "\n",
    "            # Bricks of one row are joined with the '|' operator ...\n",
    "            row_bricks = pw.stack(col_brick_list, 0.05, '|')\n",
    "            row_bricks_list.append(row_bricks)\n",
    "\n",
    "        # ... and the rows are joined with the '/' operator.\n",
    "        bricks: pw.Bricks = pw.stack(row_bricks_list, 0.05, '/')\n",
    "        bricks.set_suptitle(suptitle)\n",
    "        return bricks\n",
    "\n",
    "    @abstractmethod\n",
    "    def _define_artificial_instance(self, var: Any) -> DataUtil:\n",
    "        \"\"\"Return the ``DataUtil`` describing the experiment for one entry of ``variables``.\"\"\"\n",
    "        pass\n",
    "\n",
    "    def __call__(self) -> pw.Bricks:\n",
    "        \"\"\"Load the results for every entry of ``variables`` and build the block.\"\"\"\n",
    "        # Size the buffers from the actual number of columns. A hard-coded width\n",
    "        # would leave uninitialised torch.empty() columns in the figure when fewer\n",
    "        # values are given, and fail with an IndexError when more are given.\n",
    "        n_cols = len(self.variables)\n",
    "        decision_map_block = torch.empty(3, n_cols, resolution, resolution)\n",
    "        acc_block = torch.empty(3, n_cols)\n",
    "\n",
    "        for i, var in enumerate(self.variables):\n",
    "            n = self._define_artificial_instance(var)\n",
    "            decision_maps, acc_list = n.get_decision_maps_and_acc_list()\n",
    "            decision_map_block[:, i] = decision_maps\n",
    "            acc_block[:, i] = acc_list\n",
    "\n",
    "        return self._construct_block(decision_map_block, acc_block, self.suptitle, self.titles, self.ylabels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input Dimension Block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclasses.dataclass\n",
    "class InputDimensionBlock(Block):\n",
    "    \"\"\"Block whose columns sweep the input dimension.\n",
    "\n",
    "    Each entry of ``in_dims`` is paired position-wise with the entry of\n",
    "    ``perturbation_constraints`` used together with it.\n",
    "    \"\"\"\n",
    "    in_dims: Tuple[int, int, int, int]\n",
    "    hidden_dim: int\n",
    "    n_sample: int\n",
    "    n_noise_sample: int\n",
    "    norm: Literal['L0', 'L2', 'Linf']\n",
    "    mode: Literal['uniform', 'gauss']\n",
    "    perturbation_constraints: Tuple[float, float, float, float]\n",
    "    seed: int\n",
    "    ylabels: Optional[Tuple[str, str, str]] = None\n",
    "    suptitle: str = r'Input dimension $d$'\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        # Every column is an (input dimension, constraint) pair; only the\n",
    "        # dimension is shown as the column title.\n",
    "        self.variables = list(zip(self.in_dims, self.perturbation_constraints))\n",
    "        self.titles = self.in_dims\n",
    "\n",
    "    def _define_artificial_instance(self, var: Tuple[int, float]) -> DataUtil:\n",
    "        \"\"\"Create the ``DataUtil`` for one ``(in_dim, perturbation_constraint)`` pair.\"\"\"\n",
    "        in_dim, perturbation_constraint = var\n",
    "        return DataUtil(\n",
    "            in_dim=in_dim,\n",
    "            hidden_dim=self.hidden_dim,\n",
    "            n_sample=self.n_sample,\n",
    "            n_noise_sample=self.n_noise_sample,\n",
    "            norm=self.norm,\n",
    "            mode=self.mode,\n",
    "            perturbation_constraint=perturbation_constraint,\n",
    "            seed=self.seed,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Natural Sample Block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclasses.dataclass\n",
    "class NaturalSampleBlock(Block):\n",
    "    \"\"\"Block whose columns sweep ``n_sample``; all other settings are held fixed.\"\"\"\n",
    "    in_dim: int\n",
    "    hidden_dim: int\n",
    "    n_samples: Tuple[int, int, int, int]\n",
    "    n_noise_sample: int\n",
    "    norm: Literal['L0', 'L2', 'Linf']\n",
    "    mode: Literal['uniform', 'gauss']\n",
    "    perturbation_constraint: float\n",
    "    seed: int\n",
    "    ylabels: Optional[Tuple[str, str, str]] = None\n",
    "    suptitle: str = r'Natural sample $N$'\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        # The swept sample counts are both the columns and (via the base class) their titles.\n",
    "        self.variables = self.n_samples\n",
    "        super().__post_init__()\n",
    "\n",
    "    def _define_artificial_instance(self, var: int) -> DataUtil:\n",
    "        \"\"\"Create the ``DataUtil`` for the experiment that used ``var`` as ``n_sample``.\"\"\"\n",
    "        return DataUtil(\n",
    "            in_dim=self.in_dim,\n",
    "            hidden_dim=self.hidden_dim,\n",
    "            n_sample=var,\n",
    "            n_noise_sample=self.n_noise_sample,\n",
    "            norm=self.norm,\n",
    "            mode=self.mode,\n",
    "            perturbation_constraint=self.perturbation_constraint,\n",
    "            seed=self.seed,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Noise Sample Block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclasses.dataclass\n",
    "class NoiseSampleBlock(Block):\n",
    "    \"\"\"Block whose columns sweep ``n_noise_sample``; all other settings are held fixed.\"\"\"\n",
    "    in_dim: int\n",
    "    hidden_dim: int\n",
    "    n_sample: int\n",
    "    n_noise_samples: Tuple[int, int, int, int]\n",
    "    norm: Literal['L0', 'L2', 'Linf']\n",
    "    mode: Literal['uniform', 'gauss']\n",
    "    perturbation_constraint: float\n",
    "    seed: int\n",
    "    ylabels: Optional[Tuple[str, str, str]] = None\n",
    "    suptitle: str = r'(Adversarial) noise sample $N^{\\mathrm{adv}}$'\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        # The swept noise-sample counts are both the columns and (via the base class) their titles.\n",
    "        self.variables = self.n_noise_samples\n",
    "        super().__post_init__()\n",
    "\n",
    "    def _define_artificial_instance(self, var: int) -> DataUtil:\n",
    "        \"\"\"Create the ``DataUtil`` for the experiment that used ``var`` as ``n_noise_sample``.\"\"\"\n",
    "        return DataUtil(\n",
    "            in_dim=self.in_dim,\n",
    "            hidden_dim=self.hidden_dim,\n",
    "            n_sample=self.n_sample,\n",
    "            n_noise_sample=var,\n",
    "            norm=self.norm,\n",
    "            mode=self.mode,\n",
    "            perturbation_constraint=self.perturbation_constraint,\n",
    "            seed=self.seed,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perturbation Constraint Block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclasses.dataclass\n",
    "class PerturbationConstraintBlock(Block):\n",
    "    \"\"\"Block whose columns sweep the perturbation constraint.\n",
    "\n",
    "    When ``suptitle`` is left as ``None`` it is chosen from ``norm``.\n",
    "    Note that the ``n_noise_samples`` field holds a single count, which is\n",
    "    forwarded to ``DataUtil`` as ``n_noise_sample``.\n",
    "    \"\"\"\n",
    "    in_dim: int\n",
    "    hidden_dim: int\n",
    "    n_sample: int\n",
    "    n_noise_samples: int\n",
    "    norm: Literal['L0', 'L2', 'Linf']\n",
    "    mode: Literal['uniform', 'gauss']\n",
    "    perturbation_constraints: Tuple[float, float, float, float]\n",
    "    seed: int\n",
    "    ylabels: Optional[Tuple[str, str, str]] = None\n",
    "    suptitle: Optional[str] = None\n",
    "\n",
    "    def __post_init__(self) -> None:\n",
    "        if self.suptitle is None:\n",
    "            # L2 and Linf get the epsilon title; any other norm gets the pixel-ratio title.\n",
    "            uses_epsilon = self.norm in ['L2', 'Linf']\n",
    "            self.suptitle = (\n",
    "                r'Perturbation constraint $\\epsilon$' if uses_epsilon\n",
    "                else r'Modified pixel ratio $d_\\delta/d$'\n",
    "            )\n",
    "        self.variables = self.perturbation_constraints\n",
    "        super().__post_init__()\n",
    "\n",
    "    def _define_artificial_instance(self, var: float) -> DataUtil:\n",
    "        \"\"\"Create the ``DataUtil`` for the experiment that used ``var`` as ``perturbation_constraint``.\"\"\"\n",
    "        return DataUtil(\n",
    "            in_dim=self.in_dim,\n",
    "            hidden_dim=self.hidden_dim,\n",
    "            n_sample=self.n_sample,\n",
    "            n_noise_sample=self.n_noise_samples,\n",
    "            norm=self.norm,\n",
    "            mode=self.mode,\n",
    "            perturbation_constraint=var,\n",
    "            seed=self.seed,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# All Blocks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def all_blocks(\n",
    "    in_dim: int,\n",
    "    in_dims: Tuple[int, int, int, int],\n",
    "    perturbation_constraints_along_with_in_dims: Tuple[float, float, float, float],\n",
    "    hidden_dim: int,\n",
    "    n_sample: int,\n",
    "    n_samples: Tuple[int, int, int, int],\n",
    "    n_noise_sample: int,\n",
    "    n_noise_samples: Tuple[int, int, int, int],\n",
    "    norm: Literal['L0', 'L2', 'Linf'],\n",
    "    mode: Literal['uniform', 'gauss'],\n",
    "    perturbation_constraint: float,\n",
    "    perturbation_constraints: Tuple[float, float, float, float],\n",
    "    seed: int,\n",
    ") -> None:\n",
    "    \"\"\"Render the four-block decision-map figure for one (norm, mode) setting and save it.\n",
    "\n",
    "    The input-dimension and noise-sample blocks form the ``top`` row, the\n",
    "    natural-sample and perturbation-constraint blocks the ``bottom`` row. Only\n",
    "    the first block of each row receives the global ``ylabels``.\n",
    "\n",
    "    Args:\n",
    "        in_dim, hidden_dim, n_sample, n_noise_sample, perturbation_constraint:\n",
    "            Fixed settings used by every block that does not sweep that quantity.\n",
    "        in_dims: Input dimensions swept by the input-dimension block.\n",
    "        perturbation_constraints_along_with_in_dims: Constraint paired with each entry of ``in_dims``.\n",
    "        n_samples: Values swept by the natural-sample block.\n",
    "        n_noise_samples: Values swept by the noise-sample block.\n",
    "        perturbation_constraints: Values swept by the perturbation-constraint block.\n",
    "        norm, mode, seed: Passed unchanged to every block; ``norm`` and ``mode`` also name the output file.\n",
    "\n",
    "    The figure is written to ``<root>/figs/decision_maps_{norm}_{mode}.pdf``.\n",
    "    \"\"\"\n",
    "    in_dim_block = InputDimensionBlock(in_dims, hidden_dim, n_sample, n_noise_sample, \n",
    "                                       norm, mode, perturbation_constraints_along_with_in_dims, seed, ylabels)()\n",
    "    noise_sample_block = NoiseSampleBlock(in_dim, hidden_dim, n_sample, n_noise_samples, \n",
    "                                          norm, mode, perturbation_constraint, seed)()\n",
    "    natural_sample_block = NaturalSampleBlock(in_dim, hidden_dim, n_samples, n_noise_sample, \n",
    "                                              norm, mode, perturbation_constraint, seed, ylabels)()\n",
    "    perturbation_constraint_block = PerturbationConstraintBlock(in_dim, hidden_dim, n_sample, n_noise_sample, \n",
    "                                                                norm, mode, perturbation_constraints, seed)()\n",
    "\n",
    "    top = pw.stack([in_dim_block, noise_sample_block], margin=0.2, operator='|')\n",
    "    bottom = pw.stack([natural_sample_block, perturbation_constraint_block], margin=0.2, operator='|')\n",
    "    # Named ``figure`` rather than ``all`` so the builtin all() is not shadowed.\n",
    "    figure = pw.stack([top, bottom], margin=0.2, operator='/')\n",
    "\n",
    "    path = os.path.join(root, 'figs', f'decision_maps_{norm}_{mode}.pdf')\n",
    "    figure.savefig(path, bbox_inches='tight', pad_inches=0.025)\n",
    "\n",
    "    # Release patchworklib / matplotlib state so consecutive calls start clean.\n",
    "    pw.clear()\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# L0 / Uniform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the L0 / uniform figure.\n",
    "in_dim = 10000\n",
    "in_dims = (100, 500, 1000, 10000)\n",
    "perturbation_constraints_along_with_in_dims = (0.05, 0.05, 0.05, 0.05)\n",
    "hidden_dim = 1000\n",
    "n_sample = 1000\n",
    "n_samples = (1000, 2000, 5000, 10000)\n",
    "n_noise_sample = 10000\n",
    "n_noise_samples = (1, 10, 100, 10000)\n",
    "norm = 'L0'\n",
    "mode = 'uniform'\n",
    "perturbation_constraint = 0.05\n",
    "perturbation_constraints = (0.0001, 0.0004, 0.001, 0.05)\n",
    "seed = 5\n",
    "\n",
    "all_blocks(\n",
    "    in_dim=in_dim,\n",
    "    in_dims=in_dims,\n",
    "    perturbation_constraints_along_with_in_dims=perturbation_constraints_along_with_in_dims,\n",
    "    hidden_dim=hidden_dim,\n",
    "    n_sample=n_sample,\n",
    "    n_samples=n_samples,\n",
    "    n_noise_sample=n_noise_sample,\n",
    "    n_noise_samples=n_noise_samples,\n",
    "    norm=norm,\n",
    "    mode=mode,\n",
    "    perturbation_constraint=perturbation_constraint,\n",
    "    perturbation_constraints=perturbation_constraints,\n",
    "    seed=seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# L0 / Gauss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the L0 / gauss figure.\n",
    "in_dim = 10000\n",
    "in_dims = (100, 500, 1000, 10000)\n",
    "perturbation_constraints_along_with_in_dims = (0.05, 0.05, 0.05, 0.05)\n",
    "hidden_dim = 1000\n",
    "n_sample = 1000\n",
    "n_samples = (1000, 2000, 5000, 10000)\n",
    "n_noise_sample = 10000\n",
    "n_noise_samples = (1, 10, 100, 10000)\n",
    "norm = 'L0'\n",
    "mode = 'gauss'\n",
    "perturbation_constraint = 0.05\n",
    "perturbation_constraints = (0.0001, 0.0004, 0.001, 0.05)\n",
    "seed = 2\n",
    "\n",
    "all_blocks(\n",
    "    in_dim=in_dim,\n",
    "    in_dims=in_dims,\n",
    "    perturbation_constraints_along_with_in_dims=perturbation_constraints_along_with_in_dims,\n",
    "    hidden_dim=hidden_dim,\n",
    "    n_sample=n_sample,\n",
    "    n_samples=n_samples,\n",
    "    n_noise_sample=n_noise_sample,\n",
    "    n_noise_samples=n_noise_samples,\n",
    "    norm=norm,\n",
    "    mode=mode,\n",
    "    perturbation_constraint=perturbation_constraint,\n",
    "    perturbation_constraints=perturbation_constraints,\n",
    "    seed=seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# L2 / Uniform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the L2 / uniform figure.\n",
    "in_dim = 10000\n",
    "in_dims = (100, 500, 1000, 10000)\n",
    "perturbation_constraints_along_with_in_dims = (0.078, 0.17, 0.24, 0.78)\n",
    "hidden_dim = 1000\n",
    "n_sample = 1000\n",
    "n_samples = (1000, 2000, 5000, 10000)\n",
    "n_noise_sample = 10000\n",
    "n_noise_samples = (1, 10, 100, 10000)\n",
    "norm = 'L2'\n",
    "mode = 'uniform'\n",
    "perturbation_constraint = 0.78\n",
    "perturbation_constraints = (0.01, 0.05, 0.1, 0.78)\n",
    "seed = 5\n",
    "\n",
    "all_blocks(\n",
    "    in_dim=in_dim,\n",
    "    in_dims=in_dims,\n",
    "    perturbation_constraints_along_with_in_dims=perturbation_constraints_along_with_in_dims,\n",
    "    hidden_dim=hidden_dim,\n",
    "    n_sample=n_sample,\n",
    "    n_samples=n_samples,\n",
    "    n_noise_sample=n_noise_sample,\n",
    "    n_noise_samples=n_noise_samples,\n",
    "    norm=norm,\n",
    "    mode=mode,\n",
    "    perturbation_constraint=perturbation_constraint,\n",
    "    perturbation_constraints=perturbation_constraints,\n",
    "    seed=seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# L2 / Gauss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the L2 / gauss figure.\n",
    "in_dim = 10000\n",
    "in_dims = (100, 500, 1000, 10000)\n",
    "perturbation_constraints_along_with_in_dims = (0.078, 0.17, 0.24, 0.78)\n",
    "hidden_dim = 1000\n",
    "n_sample = 1000\n",
    "n_samples = (1000, 2000, 5000, 10000)\n",
    "n_noise_sample = 10000\n",
    "n_noise_samples = (1, 10, 100, 10000)\n",
    "norm = 'L2'\n",
    "mode = 'gauss'\n",
    "perturbation_constraint = 0.78\n",
    "perturbation_constraints = (0.01, 0.05, 0.1, 0.78)\n",
    "seed = 2\n",
    "\n",
    "all_blocks(\n",
    "    in_dim=in_dim,\n",
    "    in_dims=in_dims,\n",
    "    perturbation_constraints_along_with_in_dims=perturbation_constraints_along_with_in_dims,\n",
    "    hidden_dim=hidden_dim,\n",
    "    n_sample=n_sample,\n",
    "    n_samples=n_samples,\n",
    "    n_noise_sample=n_noise_sample,\n",
    "    n_noise_samples=n_noise_samples,\n",
    "    norm=norm,\n",
    "    mode=mode,\n",
    "    perturbation_constraint=perturbation_constraint,\n",
    "    perturbation_constraints=perturbation_constraints,\n",
    "    seed=seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Linf / Uniform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the Linf / uniform figure.\n",
    "in_dim = 10000\n",
    "in_dims = (100, 500, 1000, 10000)\n",
    "perturbation_constraints_along_with_in_dims = (0.03, 0.03, 0.03, 0.03)\n",
    "hidden_dim = 1000\n",
    "n_sample = 1000\n",
    "n_samples = (1000, 2000, 5000, 10000)\n",
    "n_noise_sample = 10000\n",
    "n_noise_samples = (1, 10, 100, 10000)\n",
    "norm = 'Linf'\n",
    "mode = 'uniform'\n",
    "perturbation_constraint = 0.03\n",
    "perturbation_constraints = (0.001, 0.005, 0.01, 0.03)\n",
    "seed = 5\n",
    "\n",
    "all_blocks(\n",
    "    in_dim=in_dim,\n",
    "    in_dims=in_dims,\n",
    "    perturbation_constraints_along_with_in_dims=perturbation_constraints_along_with_in_dims,\n",
    "    hidden_dim=hidden_dim,\n",
    "    n_sample=n_sample,\n",
    "    n_samples=n_samples,\n",
    "    n_noise_sample=n_noise_sample,\n",
    "    n_noise_samples=n_noise_samples,\n",
    "    norm=norm,\n",
    "    mode=mode,\n",
    "    perturbation_constraint=perturbation_constraint,\n",
    "    perturbation_constraints=perturbation_constraints,\n",
    "    seed=seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Linf / Gauss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the Linf / gauss figure.\n",
    "in_dim = 10000\n",
    "in_dims = (100, 500, 1000, 10000)\n",
    "perturbation_constraints_along_with_in_dims = (0.03, 0.03, 0.03, 0.03)\n",
    "hidden_dim = 1000\n",
    "n_sample = 1000\n",
    "n_samples = (1000, 2000, 5000, 10000)\n",
    "n_noise_sample = 10000\n",
    "n_noise_samples = (1, 10, 100, 10000)\n",
    "norm = 'Linf'\n",
    "mode = 'gauss'\n",
    "perturbation_constraint = 0.03\n",
    "perturbation_constraints = (0.001, 0.005, 0.01, 0.03)\n",
    "seed = 2\n",
    "\n",
    "all_blocks(\n",
    "    in_dim=in_dim,\n",
    "    in_dims=in_dims,\n",
    "    perturbation_constraints_along_with_in_dims=perturbation_constraints_along_with_in_dims,\n",
    "    hidden_dim=hidden_dim,\n",
    "    n_sample=n_sample,\n",
    "    n_samples=n_samples,\n",
    "    n_noise_sample=n_noise_sample,\n",
    "    n_noise_samples=n_noise_samples,\n",
    "    norm=norm,\n",
    "    mode=mode,\n",
    "    perturbation_constraint=perturbation_constraint,\n",
    "    perturbation_constraints=perturbation_constraints,\n",
    "    seed=seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Two Blocks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Larger font scale for the two-block figure generated by the cells that follow.\n",
    "Figure.set_font_scale(1.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def two_blocks(\n",
    "    in_dim: int,\n",
    "    in_dims: Tuple[int, int, int, int],\n",
    "    perturbation_constraints_along_with_in_dims: Tuple[float, float, float, float],\n",
    "    hidden_dim: int,\n",
    "    n_sample: int,\n",
    "    n_noise_sample: int,\n",
    "    n_noise_samples: Tuple[int, int, int, int],\n",
    "    norm: Literal['L2', 'L0'],\n",
    "    mode: Literal['uniform', 'gauss'],\n",
    "    perturbation_constraint: float,\n",
    "    seed: int,\n",
    ") -> None:\n",
    "    \"\"\"Render the input-dimension and noise-sample blocks side by side and save them.\n",
    "\n",
    "    Only the input-dimension block receives the global ``ylabels``. The figure\n",
    "    is written to ``<root>/figs/decision_maps_{norm}_{mode}_two.pdf``.\n",
    "    \"\"\"\n",
    "    make_in_dim_block = InputDimensionBlock(\n",
    "        in_dims, hidden_dim, n_sample, n_noise_sample,\n",
    "        norm, mode, perturbation_constraints_along_with_in_dims, seed, ylabels,\n",
    "    )\n",
    "    make_noise_sample_block = NoiseSampleBlock(\n",
    "        in_dim, hidden_dim, n_sample, n_noise_samples,\n",
    "        norm, mode, perturbation_constraint, seed,\n",
    "    )\n",
    "\n",
    "    # Calling a block instance loads its data and builds the bricks.\n",
    "    in_dim_block = make_in_dim_block()\n",
    "    noise_sample_block = make_noise_sample_block()\n",
    "    combined = pw.stack([in_dim_block, noise_sample_block], margin=0.2, operator='|')\n",
    "\n",
    "    out_name = f'decision_maps_{norm}_{mode}_two.pdf'\n",
    "    combined.savefig(os.path.join(root, 'figs', out_name), bbox_inches='tight', pad_inches=0.025)\n",
    "\n",
    "    # Release patchworklib / figure state before the next figure is built.\n",
    "    pw.clear()\n",
    "    Figure.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the two-block L2 / uniform figure.\n",
    "# n_samples and perturbation_constraints are assigned here but not passed to two_blocks.\n",
    "in_dim = 10000\n",
    "in_dims = (100, 500, 1000, 10000)\n",
    "perturbation_constraints_along_with_in_dims = (0.078, 0.17, 0.24, 0.78)\n",
    "hidden_dim = 1000\n",
    "n_sample = 1000\n",
    "n_samples = (1000, 2000, 5000, 10000)\n",
    "n_noise_sample = 10000\n",
    "n_noise_samples = (1, 10, 100, 10000)\n",
    "norm = 'L2'\n",
    "mode = 'uniform'\n",
    "perturbation_constraint = 0.78\n",
    "perturbation_constraints = (0.01, 0.05, 0.1, 0.78)\n",
    "seed = 5\n",
    "\n",
    "two_blocks(\n",
    "    in_dim=in_dim,\n",
    "    in_dims=in_dims,\n",
    "    perturbation_constraints_along_with_in_dims=perturbation_constraints_along_with_in_dims,\n",
    "    hidden_dim=hidden_dim,\n",
    "    n_sample=n_sample,\n",
    "    n_noise_sample=n_noise_sample,\n",
    "    n_noise_samples=n_noise_samples,\n",
    "    norm=norm,\n",
    "    mode=mode,\n",
    "    perturbation_constraint=perturbation_constraint,\n",
    "    seed=seed,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
