{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be38b467-6d51-406c-a1f1-3df6005cabd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "sys.path.append('./src/')\n",
    "\n",
    "\n",
    "from PIL import Image\n",
    "from VAE_trainers import EpochPyroTrainer, AdversarialEpochPyroTrainer, ThresholdPyroTrainer, AdversarialThresholdPyroTrainer\n",
    "from CNN_variants import CNNVAE, CNNCVAE, CNNCSVAENA, CNNCSVAE, CNNHCSVAENA, CNNHCSVAE, CNNSDIVA, CNNCCVAE, CNNDLVAE\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "from tqdm import tqdm, trange\n",
    "from umap import UMAP\n",
    "from torchvision.datasets import CelebA, MNIST\n",
    "from torchvision.transforms import ToTensor\n",
    "\n",
    "\n",
    "import torch, pyro\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import copy, cv2\n",
    "import pyro.optim as opt\n",
    "import pandas as pd\n",
    "\n",
    "# Seed every RNG used below (numpy, torch, pyro) so runs are reproducible\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "pyro.util.set_rng_seed(SEED)\n",
    "\n",
    "\n",
    "# Subset wrapper for attributes\n",
    "class SubsetWrapper(torch.utils.data.Dataset):\n",
    "    \"\"\"View of `base_dataset` restricted to the positions listed in `idxs`.\n",
    "\n",
    "    Items are returned as (image, label, label). A single integer index (Python or\n",
    "    numpy) yields one image and its label with a leading dim of size 1 added; a\n",
    "    slice, list or numpy-array index yields torch.stack-ed images and\n",
    "    torch.vstack-ed labels.\n",
    "    NOTE(review): the label is duplicated presumably because the trainers expect\n",
    "    (x, y, y) batches -- confirm against VAE_trainers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, base_dataset, idxs):\n",
    "        self.base_dataset = base_dataset\n",
    "        self.idxs = idxs\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        if isinstance(idx, np.ndarray):\n",
    "            elems = np.asarray(self.idxs)[idx]\n",
    "            # a 0-d index array selects one position; keep that result scalar\n",
    "            if isinstance(elems, np.ndarray):\n",
    "                elems = list(elems)\n",
    "        else:\n",
    "            elems = self.idxs[idx]\n",
    "\n",
    "        # numpy integers (e.g. when idxs is an array) must take the single-item path\n",
    "        # too; otherwise they fall into the batch branch and fail when iterated\n",
    "        if isinstance(elems, (int, np.integer)):\n",
    "            im, lab = self.base_dataset[elems]\n",
    "            lab = lab.unsqueeze(dim=0)\n",
    "        else:\n",
    "            ims, labs = [], []\n",
    "            for base_idx in elems:\n",
    "                c_im, c_lab = self.base_dataset[base_idx]\n",
    "                ims.append(c_im)\n",
    "                labs.append(c_lab)\n",
    "            im, lab = torch.stack(ims), torch.vstack(labs)\n",
    "\n",
    "        return im, lab, lab\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.idxs)\n",
    "        \n",
    "# Two-colour colormap interpolating from red (#F23E2E) to blue (#5888A6)\n",
    "hex_colors = [\"#F23E2E\", \"#5888A6\"]\n",
    "cmap = LinearSegmentedColormap.from_list(name=\"cmap\", colors=hex_colors)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c51a34e2-5f83-484d-9338-e72b047225ca",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Resize images (run only to preprocess / augment the raw images; rename the output directory manually afterwards)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b010849-0828-46e0-9d94-482d4043bb80",
   "metadata": {},
   "outputs": [],
   "source": [
    "from concurrent.futures import ProcessPoolExecutor\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "from skimage import transform, filters, color\n",
    "\n",
    "\n",
    "def sample_img_augment_params(translation_sigma=1.0, scale_sigma=0.01,\n",
    "                              rotation_sigma=0.01, gamma_sigma=0.07,\n",
    "                              contrast_sigma=0.07, hue_sigma=0.0125):\n",
    "    \"\"\"Draw one random parameter set for `img_augment` from numpy's global RNG.\n",
    "\n",
    "    Draws happen in a fixed order (translation, scale, rotation, gamma, contrast,\n",
    "    hue), so seeding `np.random` reproduces them. Returns the tuple\n",
    "    (translation[2], scale, rotation, gamma, contrast, hue).\n",
    "    \"\"\"\n",
    "    def _exp_jitter(sigma):\n",
    "        # normal draw centred on sigma**2, mapped through exp(x / ln 2)\n",
    "        return np.exp(np.random.normal(loc=sigma**2, scale=sigma) / np.log(2))\n",
    "\n",
    "    translation = np.random.normal(scale=translation_sigma, size=2)\n",
    "    scale = np.random.normal(loc=1.0, scale=scale_sigma)\n",
    "    rotation = np.random.normal(scale=rotation_sigma)\n",
    "    gamma = _exp_jitter(gamma_sigma)\n",
    "    contrast = _exp_jitter(contrast_sigma)\n",
    "    hue = np.random.normal(scale=hue_sigma)\n",
    "    return translation, scale, rotation, gamma, contrast, hue\n",
    "    \n",
    "\n",
    "def img_augment(img, translation=0.0, scale=1.0, rotation=0.0, gamma=1.0,\n",
    "                contrast=1.0, hue=0.0, border_mode='constant'):\n",
    "    \"\"\"Apply a similarity warp plus gamma / saturation-value / hue jitter to an image.\n",
    "\n",
    "    The parameters match, in order, the tuple returned by `sample_img_augment_params`.\n",
    "    `img` is assumed to be an (H, W, 3) float RGB array in [0, 1] -- TODO confirm for\n",
    "    callers other than `_resize_augment`. The input array is not modified.\n",
    "\n",
    "    Returns the augmented RGB image; raises ValueError if it leaves [0, 1].\n",
    "    \"\"\"\n",
    "    if not (np.all(np.isclose(translation, [0.0, 0.0])) and\n",
    "            np.isclose(scale, 1.0) and\n",
    "            np.isclose(rotation, 0.0)):\n",
    "        # skimage transforms work in (x, y) = (col, row) order, so the centre is\n",
    "        # (W/2, H/2); using img.shape[:2] directly swaps the axes on non-square images.\n",
    "        img_center = np.array(img.shape[1::-1]) / 2.0\n",
    "        scale = (scale, scale)\n",
    "        # move centre to origin -> scale/rotate -> move back (plus translation jitter)\n",
    "        transf = transform.SimilarityTransform(translation=-img_center)\n",
    "        transf += transform.SimilarityTransform(scale=scale, rotation=rotation)\n",
    "        translation = img_center + translation\n",
    "        transf += transform.SimilarityTransform(translation=translation)\n",
    "        img = transform.warp(img, transf, order=3, mode=border_mode)\n",
    "    if not np.isclose(gamma, 1.0):\n",
    "        # out-of-place so the caller's array is never mutated\n",
    "        img = img ** gamma\n",
    "    colorspace = 'rgb'\n",
    "    if not np.isclose(contrast, 1.0):\n",
    "        img = color.convert_colorspace(img, colorspace, 'hsv')\n",
    "        colorspace = 'hsv'\n",
    "        img[..., 1:] **= contrast\n",
    "    if not np.isclose(hue, 0.0):\n",
    "        img = color.convert_colorspace(img, colorspace, 'hsv')\n",
    "        colorspace = 'hsv'\n",
    "        img[..., 0] += hue\n",
    "        # wrap hue back into [0, 1]\n",
    "        img[img[..., 0] > 1.0, 0] -= 1.0\n",
    "        img[img[..., 0] < 0.0, 0] += 1.0\n",
    "    img = color.convert_colorspace(img, colorspace, 'rgb')\n",
    "    if np.min(img) < 0.0 or np.max(img) > 1.0:\n",
    "        raise ValueError('Invalid values in output image.')\n",
    "    return img\n",
    "\n",
    "\n",
    "def _resize(args):\n",
    "    \"\"\"Crop, anti-alias and resize one image to a square uint8 RGB image.\n",
    "\n",
    "    `args` is an (img, rescale_size, bbox) triple with\n",
    "    bbox = (row_start, row_end, col_start, col_end). `img` is assumed to be a float\n",
    "    array in [0, 1] with a trailing colour axis of length 3 -- TODO confirm.\n",
    "    Requires scikit-image >= 0.19 for the `channel_axis` argument.\n",
    "    \"\"\"\n",
    "    img, rescale_size, bbox = args\n",
    "    img = img[bbox[0]:bbox[1], bbox[2]:bbox[3]]\n",
    "    # Smooth image before resize to avoid moire patterns\n",
    "    scale = img.shape[0] / float(rescale_size)\n",
    "    sigma = np.sqrt(scale) / 2.0\n",
    "    # channel_axis=-1 blurs each colour channel separately; without it, scikit-image\n",
    "    # versions that no longer auto-detect RGB input also smooth across R/G/B\n",
    "    img = filters.gaussian(img, sigma=sigma, channel_axis=-1)\n",
    "    img = transform.resize(img, (rescale_size, rescale_size, 3), order=3)\n",
    "    img = (img*255).astype(np.uint8)\n",
    "    return img\n",
    "\n",
    "\n",
    "def _resize_augment(args):\n",
    "    \"\"\"Randomly augment an image, then crop / smooth / resize it via `_resize`.\n",
    "\n",
    "    `args` is the same (img, rescale_size, bbox) triple that `_resize` takes.\n",
    "    \"\"\"\n",
    "    img, rescale_size, bbox = args\n",
    "    jitter = sample_img_augment_params(translation_sigma=2.00, scale_sigma=0.01,\n",
    "                                       rotation_sigma=0.01, gamma_sigma=0.05,\n",
    "                                       contrast_sigma=0.05, hue_sigma=0.01)\n",
    "    augmented = img_augment(img, *jitter, border_mode='constant')\n",
    "    return _resize((augmented, rescale_size, bbox))\n",
    "\n",
    "\n",
    "# Output resolution and crop box (row_start, row_end, col_start, col_end) passed to\n",
    "# _resize; 218 / 178 are presumably the aligned image height / width.\n",
    "img_size = 64\n",
    "bbox = (40, 218 - 30, 15, 178 - 15)\n",
    "\n",
    "# Input / output directories\n",
    "src_dir = 'data/celeba/img_align_celeba/'  # The dataset needs to be downloaded separately\n",
    "dst_dir = 'data/celeba/img_align_celeba_64_aug'  # Rename the outputted folder after conversion for correct loading!\n",
    "os.makedirs(dst_dir, exist_ok=True)  # no-op if it already exists\n",
    "\n",
    "# Every file name in the source directory gets processed\n",
    "base_im_dir = os.listdir(src_dir)\n",
    "\n",
    "# Function to resize and save one image\n",
    "def resize_and_save(fname):\n",
    "    \"\"\"Augment, crop and resize `src_dir/fname` and write the result to `dst_dir/fname`.\n",
    "\n",
    "    Returns None on success, or a \"Failed on ...\" message on any exception so one\n",
    "    bad file does not abort the whole parallel run.\n",
    "    \"\"\"\n",
    "    import zlib  # only needed for the per-file seed below\n",
    "\n",
    "    src_path = os.path.join(src_dir, fname)\n",
    "    dst_path = os.path.join(dst_dir, fname)\n",
    "    try:\n",
    "        # Seed numpy from the file name. With the fork start method every worker\n",
    "        # inherits the parent's RNG state and would otherwise draw the same\n",
    "        # augmentation sequence; this also makes each output independent of which\n",
    "        # worker processed it.\n",
    "        np.random.seed(zlib.crc32(fname.encode('utf-8')))\n",
    "        with Image.open(src_path) as pil_im:  # close the file handle promptly\n",
    "            im = np.asarray(pil_im) / 255\n",
    "        im = _resize_augment([im, img_size, bbox])\n",
    "        Image.fromarray(im).save(dst_path)\n",
    "\n",
    "    except Exception as e:\n",
    "        return f\"Failed on {fname}: {e}\"\n",
    "\n",
    "# Parallel execution with progress bar. resize_and_save returns an error message for\n",
    "# each failed file, so collect and report those instead of discarding them.\n",
    "with ProcessPoolExecutor() as executor:\n",
    "    results = list(tqdm(executor.map(resize_and_save, base_im_dir), total=len(base_im_dir)))\n",
    "\n",
    "failures = [msg for msg in results if msg is not None]\n",
    "print(f\"{len(failures)} of {len(base_im_dir)} images failed\")\n",
    "for msg in failures[:20]:  # show only the first few to keep the output short\n",
    "    print(msg)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b738eae-7a93-4f2d-946c-844838665318",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Data (run after setting up data dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05237b65-0b6d-4c66-8e06-d5de89277f92",
   "metadata": {},
   "outputs": [],
   "source": [
    "# list_eval_partition.txt rows are '<fname> <partition>' with partition\n",
    "# 0 = train, 1 = valid, 2 = test (the CelebA / torchvision convention). Categories\n",
    "# are renamed in sorted order, so 1 must become 'valid' and 2 'test'.\n",
    "splits = pd.read_csv('./data/celeba/list_eval_partition.txt', sep=' ', header=None)\n",
    "splits.columns = ['fname', 'split']\n",
    "splits['split'] = splits['split'].astype('category')\n",
    "splits['split'] = splits['split'].cat.rename_categories(['train', 'valid', 'test'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1f8dc03-d62f-41f4-a66d-0f51b4055867",
   "metadata": {},
   "outputs": [],
   "source": [
    "# header=1 discards the first line (presumably the image count) and reads the\n",
    "# attribute-name line. The file-name column has no header entry, so the parsed\n",
    "# names are off by one: shift them right and call the first column 'fname'.\n",
    "# NOTE(review): relies on the header line yielding as many columns as the data rows\n",
    "# (e.g. via a trailing separator) -- confirm if the file format changes.\n",
    "attrs = pd.read_csv('./data/celeba/list_attr_celeba.txt', header=1, skipinitialspace=True, sep=' ')\n",
    "attrs.columns = ['fname', *attrs.columns[:-1]]\n",
    "\n",
    "attrs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1b755ce-6a68-4f4a-8d11-fdfb8f442bbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): no sep is given, so pandas uses its default ',' and rows of this\n",
    "# space-separated file presumably land in a single column; `attrs` is the parsed table.\n",
    "attr_pd = pd.read_csv(filepath_or_buffer='./data/celeba/list_attr_celeba.txt', header=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56d0e0c4-beeb-49f4-824c-1f351a01ce30",
   "metadata": {},
   "outputs": [],
   "source": [
    "def glasses_transform(target):\n",
    "    \"\"\"Keep only entry 15 of a CelebA attribute target, cast to float32.\n",
    "\n",
    "    Index 15 is presumably Eyeglasses in CelebA's attribute order -- TODO confirm\n",
    "    via dataset_glasses.attr_names.\n",
    "    \"\"\"\n",
    "    eyeglasses = target[15]\n",
    "    return eyeglasses.to(torch.float32)\n",
    "\n",
    "\n",
    "# All CelebA images (split='all'); train/test subsets are picked by position later.\n",
    "dataset_glasses = CelebA(\n",
    "    root=\"./data\", download=False, transform=ToTensor(),\n",
    "    target_transform=glasses_transform, split='all',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d4d4096-8cef-450f-96ae-2021a6935489",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "num_workers = 16\n",
    "\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "\n",
    "## Training data creation: all Eyeglasses == 1 images of the split, followed by twice\n",
    "## as many randomly sampled Eyeglasses == -1 images. attrs' row index is used as a\n",
    "## position into dataset_glasses (presumably both follow the attribute-file order).\n",
    "attrs_train = attrs[attrs['fname'].isin(splits[splits['split'] == 'train']['fname'])]\n",
    "glasses_train_idx = attrs_train[attrs_train['Eyeglasses'] == 1].index\n",
    "glasses_train_count = len(glasses_train_idx)\n",
    "glasses_negative_train_idx = attrs_train[attrs_train['Eyeglasses'] == -1].sample(2*glasses_train_count).index\n",
    "glasses_train_idx = list(glasses_train_idx) + list(glasses_negative_train_idx)\n",
    "train_set = SubsetWrapper(dataset_glasses, glasses_train_idx)\n",
    "\n",
    "\n",
    "## Test data creation (same recipe; later cells rely on the glasses images occupying\n",
    "## test_set[:glasses_test_count])\n",
    "attrs_test = attrs[attrs['fname'].isin(splits[splits['split'] == 'test']['fname'])]\n",
    "glasses_test_idx = attrs_test[attrs_test['Eyeglasses'] == 1].index\n",
    "glasses_test_count = len(glasses_test_idx)\n",
    "glasses_negative_test_idx = attrs_test[attrs_test['Eyeglasses'] == -1].sample(2*glasses_test_count).index\n",
    "glasses_test_idx = list(glasses_test_idx) + list(glasses_negative_test_idx)\n",
    "\n",
    "test_set = SubsetWrapper(dataset_glasses, glasses_test_idx)\n",
    "\n",
    "\n",
    "## Set loaders\n",
    "# train_set is ordered [all glasses, then all non-glasses], so the training loader must\n",
    "# shuffle; otherwise every epoch runs through single-class batches. The test loader\n",
    "# keeps dataset order.\n",
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,\n",
    "                                           num_workers=num_workers, timeout=100)\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,\n",
    "                                          num_workers=num_workers, timeout=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07a5dbde-07b9-4a5c-9e06-84d3ddc5eae2",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## CSVAENA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75a654a2-b482-4899-95fe-2282aaa84f93",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Re-seed so this model's initialisation and training are reproducible on their own\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "csvaena = CNNCSVAENA(\n",
    "    (64, 64), 3, [1],\n",
    "    latent_dim=2048, w_dim=2, channels=[64, 128, 256], repeats=[2, 1, 1],\n",
    "    cnn_arch='conv+pool', num_layers=0,\n",
    "    recon_weight=1, z_kl_weight=1e-4, kernel_size=5,\n",
    ")\n",
    "csvaena_trainer = ThresholdPyroTrainer(\n",
    "    0, 50, csvaena, train_loader, test_loader,\n",
    "    opt.AdamW({\"lr\": 1e-4}),\n",
    ")\n",
    "csvaena_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "807b79b7-73b9-4ed7-ad9e-34783a4f7f27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on 100 glasses images (start of the test set) and 100 non-glasses images\n",
    "# (starting at glasses_test_count); assumes at least 100 of each. Each slice is read\n",
    "# from disk once and reused for the predictions, labels (y_s) and originals (origs).\n",
    "glasses_batch = csvaena_trainer.test_loader.dataset[:100]\n",
    "ng_batch = csvaena_trainer.test_loader.dataset[glasses_test_count:glasses_test_count+100]\n",
    "\n",
    "with torch.no_grad():  # inference only; no autograd graph needed\n",
    "    preds_glasses = csvaena_trainer.predictive(*csvaena_trainer._send_args_to_device(glasses_batch, csvaena_trainer.device))\n",
    "    preds_ng = csvaena_trainer.predictive(*csvaena_trainer._send_args_to_device(ng_batch, csvaena_trainer.device))\n",
    "\n",
    "z_s_glasses = preds_glasses['z'][0].cpu()\n",
    "w_s_glasses = preds_glasses['w'][0].cpu()\n",
    "recons_glasses = preds_glasses['rec'][0, 0].cpu()\n",
    "\n",
    "z_s_ng = preds_ng['z'][0].cpu()\n",
    "w_s_ng = preds_ng['w'][0].cpu()\n",
    "recons_ng = preds_ng['rec'][0, 0].cpu()\n",
    "\n",
    "z_s = torch.vstack((z_s_glasses, z_s_ng))\n",
    "w_s = torch.vstack((w_s_glasses, w_s_ng))\n",
    "recons = torch.vstack((recons_glasses, recons_ng))\n",
    "y_s = torch.vstack((glasses_batch[1], ng_batch[1]))\n",
    "origs = torch.vstack((glasses_batch[0], ng_batch[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f11bca7-92d4-4f85-9d40-ae9bd20cc833",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Originals: top row = test_set[0:10] (glasses block), bottom row = non-glasses images\n",
    "# test_set[glasses_test_count+10 : glasses_test_count+20], matching the next cell.\n",
    "fig, ax = plt.subplots(2, 10, figsize=(15, 3), gridspec_kw={'wspace': 0, 'hspace': -0.02})\n",
    "\n",
    "for i in range(20):\n",
    "    idx = i if i < 10 else glasses_test_count + i\n",
    "    im = test_set[idx][0]  # single load; the label is not needed for plotting\n",
    "    ax[i // 10][i % 10].imshow(np.einsum('ijk -> jki', im))\n",
    "    ax[i // 10][i % 10].axis('off')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eaed480c-3572-4e5a-972d-207e2bd4f6ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstructions in the same layout: recons stacks the 100 glasses reconstructions\n",
    "# first, then the 100 non-glasses ones, so index 100 + i pairs with\n",
    "# test_set[glasses_test_count + i].\n",
    "fig, ax = plt.subplots(2, 10, figsize=(15, 3), gridspec_kw={'wspace': 0, 'hspace': -0.02})\n",
    "\n",
    "for i in range(20):\n",
    "    rec_idx = i if i < 10 else 100 + i\n",
    "    im = torch.clamp(recons[rec_idx], 0, 1)  # outputs may leave [0, 1]; clip for imshow\n",
    "    ax[i // 10][i % 10].imshow(np.einsum('ijk -> jki', im))\n",
    "    ax[i // 10][i % 10].axis('off')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f6cef62-57ea-42dd-8395-eaba667e601d",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## CSVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "976dcc7b-9ed8-4422-b148-3542acfe0241",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed so this model's initialisation and training are reproducible on their own\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "csvae = CNNCSVAE(\n",
    "    (64, 64), 3, [1],\n",
    "    latent_dim=2048, w_dim=2, channels=[64, 128, 256], repeats=[2, 1, 1],\n",
    "    cnn_arch='conv+pool', num_layers=0,\n",
    "    recon_weight=1e3, z_kl_weight=1e-4, kernel_size=5,\n",
    ")\n",
    "csvae_trainer = AdversarialThresholdPyroTrainer(\n",
    "    0, 50, 1, 1, csvae, train_loader, test_loader,\n",
    "    opt.AdamW({\"lr\": 1e-4}),\n",
    ")\n",
    "csvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74d7c1be-e4b4-44ee-9d11-feaa7ff2b26c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on 100 glasses images (start of the test set) and 100 non-glasses images\n",
    "# (starting at glasses_test_count); assumes at least 100 of each. Each slice is read\n",
    "# from disk once and reused for the predictions, labels (y_s) and originals (origs).\n",
    "glasses_batch = csvae_trainer.test_loader.dataset[:100]\n",
    "ng_batch = csvae_trainer.test_loader.dataset[glasses_test_count:glasses_test_count+100]\n",
    "\n",
    "with torch.no_grad():  # inference only; no autograd graph needed\n",
    "    preds_glasses = csvae_trainer.predictive(*csvae_trainer._send_args_to_device(glasses_batch, csvae_trainer.device))\n",
    "    preds_ng = csvae_trainer.predictive(*csvae_trainer._send_args_to_device(ng_batch, csvae_trainer.device))\n",
    "\n",
    "z_s_glasses = preds_glasses['z'][0].cpu()\n",
    "w_s_glasses = preds_glasses['w'][0].cpu()\n",
    "recons_glasses = preds_glasses['rec'][0, 0].cpu()\n",
    "\n",
    "z_s_ng = preds_ng['z'][0].cpu()\n",
    "w_s_ng = preds_ng['w'][0].cpu()\n",
    "recons_ng = preds_ng['rec'][0, 0].cpu()\n",
    "\n",
    "z_s = torch.vstack((z_s_glasses, z_s_ng))\n",
    "w_s = torch.vstack((w_s_glasses, w_s_ng))\n",
    "recons = torch.vstack((recons_glasses, recons_ng))\n",
    "y_s = torch.vstack((glasses_batch[1], ng_batch[1]))\n",
    "origs = torch.vstack((glasses_batch[0], ng_batch[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5e4196c-8919-4411-a062-347394827b63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Originals: top row = test_set[0:10] (glasses block), bottom row = non-glasses images\n",
    "# test_set[glasses_test_count+10 : glasses_test_count+20], matching the next cell.\n",
    "fig, ax = plt.subplots(2, 10, figsize=(15, 3), gridspec_kw={'wspace': 0, 'hspace': -0.02})\n",
    "\n",
    "for i in range(20):\n",
    "    idx = i if i < 10 else glasses_test_count + i\n",
    "    im = test_set[idx][0]  # single load; the label is not needed for plotting\n",
    "    ax[i // 10][i % 10].imshow(np.einsum('ijk -> jki', im))\n",
    "    ax[i // 10][i % 10].axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa9f58c2-cfd6-4ed6-97ac-48645415b576",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstructions in the same layout: recons stacks the 100 glasses reconstructions\n",
    "# first, then the 100 non-glasses ones, so index 100 + i pairs with\n",
    "# test_set[glasses_test_count + i].\n",
    "fig, ax = plt.subplots(2, 10, figsize=(15, 3), gridspec_kw={'wspace': 0, 'hspace': -0.02})\n",
    "\n",
    "for i in range(20):\n",
    "    rec_idx = i if i < 10 else 100 + i\n",
    "    im = torch.clamp(recons[rec_idx], 0, 1)  # outputs may leave [0, 1]; clip for imshow\n",
    "    ax[i // 10][i % 10].imshow(np.einsum('ijk -> jki', im))\n",
    "    ax[i // 10][i % 10].axis('off')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1063ec2-6556-47dd-8550-0e155da7e325",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## HCSVAENA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33cdc762-5cec-4727-a724-ce2693bb9d6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed so this model's initialisation and training are reproducible on their own\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "hcsvaena = CNNHCSVAENA(\n",
    "    (64, 64), 3, [1],\n",
    "    latent_dim=2048, w_dim=2, channels=[64, 128, 256], repeats=[2, 1, 1],\n",
    "    cnn_arch='conv+pool', num_layers=0,\n",
    "    recon_weight=1e3, z_kl_weight=1e-4, kernel_size=5,\n",
    ")\n",
    "hcsvaena_trainer = ThresholdPyroTrainer(\n",
    "    0, 50, hcsvaena, train_loader, test_loader,\n",
    "    opt.AdamW({\"lr\": 1e-4}),\n",
    ")\n",
    "hcsvaena_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85e974ba-33fc-430f-9b32-9f45dfd26b87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on 100 glasses images (start of the test set) and 100 non-glasses images\n",
    "# (starting at glasses_test_count); assumes at least 100 of each. Each slice is read\n",
    "# from disk once and reused for the predictions, labels (y_s) and originals (origs).\n",
    "glasses_batch = hcsvaena_trainer.test_loader.dataset[:100]\n",
    "ng_batch = hcsvaena_trainer.test_loader.dataset[glasses_test_count:glasses_test_count+100]\n",
    "\n",
    "with torch.no_grad():  # inference only; no autograd graph needed\n",
    "    preds_glasses = hcsvaena_trainer.predictive(*hcsvaena_trainer._send_args_to_device(glasses_batch, hcsvaena_trainer.device))\n",
    "    preds_ng = hcsvaena_trainer.predictive(*hcsvaena_trainer._send_args_to_device(ng_batch, hcsvaena_trainer.device))\n",
    "\n",
    "z_s_glasses = preds_glasses['z'][0].cpu()\n",
    "w_s_glasses = preds_glasses['w'][0].cpu()\n",
    "recons_glasses = preds_glasses['rec'][0, 0].cpu()\n",
    "\n",
    "z_s_ng = preds_ng['z'][0].cpu()\n",
    "w_s_ng = preds_ng['w'][0].cpu()\n",
    "recons_ng = preds_ng['rec'][0, 0].cpu()\n",
    "\n",
    "z_s = torch.vstack((z_s_glasses, z_s_ng))\n",
    "w_s = torch.vstack((w_s_glasses, w_s_ng))\n",
    "recons = torch.vstack((recons_glasses, recons_ng))\n",
    "y_s = torch.vstack((glasses_batch[1], ng_batch[1]))\n",
    "# images go to origs (they previously overwrote the labels in y_s)\n",
    "origs = torch.vstack((glasses_batch[0], ng_batch[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fff4c882-fe73-4718-b213-308937e77c8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Originals: top row = test_set[0:10] (glasses block), bottom row = non-glasses images\n",
    "# test_set[glasses_test_count+10 : glasses_test_count+20], matching the next cell.\n",
    "fig, ax = plt.subplots(2, 10, figsize=(15, 3), gridspec_kw={'wspace': 0, 'hspace': -0.02})\n",
    "\n",
    "for i in range(20):\n",
    "    idx = i if i < 10 else glasses_test_count + i\n",
    "    im = test_set[idx][0]  # single load; the label is not needed for plotting\n",
    "    ax[i // 10][i % 10].imshow(np.einsum('ijk -> jki', im))\n",
    "    ax[i // 10][i % 10].axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8400fead-3ac6-457b-affb-528b29247cd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstructions in the same layout: recons stacks the 100 glasses reconstructions\n",
    "# first, then the 100 non-glasses ones, so index 100 + i pairs with\n",
    "# test_set[glasses_test_count + i].\n",
    "fig, ax = plt.subplots(2, 10, figsize=(15, 3), gridspec_kw={'wspace': 0, 'hspace': -0.02})\n",
    "\n",
    "for i in range(20):\n",
    "    rec_idx = i if i < 10 else 100 + i\n",
    "    im = torch.clamp(recons[rec_idx], 0, 1)  # outputs may leave [0, 1]; clip for imshow\n",
    "    ax[i // 10][i % 10].imshow(np.einsum('ijk -> jki', im))\n",
    "    ax[i // 10][i % 10].axis('off')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ef48dd5-f09e-470a-a16b-134c58e3f047",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## HCSVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2015f9d-f361-4c89-ac4c-03f08b0dd452",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Re-seed so this model's initialisation and training are reproducible on their own\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "hcsvae = CNNHCSVAE(\n",
    "    (64, 64), 3, [1],\n",
    "    latent_dim=2048, w_dim=2, channels=[64, 128, 256], repeats=[2, 1, 1],\n",
    "    cnn_arch='conv+pool', num_layers=0,\n",
    "    recon_weight=1e4, z_kl_weight=1e-4, kernel_size=5,\n",
    ")\n",
    "hcsvae_trainer = AdversarialThresholdPyroTrainer(\n",
    "    0, 50, 1, 1, hcsvae, train_loader, test_loader,\n",
    "    opt.AdamW({\"lr\": 1e-4}),\n",
    ")\n",
    "hcsvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c9e9dde-0b92-4a04-8673-080d90ecddb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on 100 glasses images (start of the test set) and 100 non-glasses images\n",
    "# (starting at glasses_test_count); assumes at least 100 of each. Each slice is read\n",
    "# from disk once and reused for the predictions, labels (y_s) and originals (origs).\n",
    "glasses_batch = hcsvae_trainer.test_loader.dataset[:100]\n",
    "ng_batch = hcsvae_trainer.test_loader.dataset[glasses_test_count:glasses_test_count+100]\n",
    "\n",
    "with torch.no_grad():  # inference only; no autograd graph needed\n",
    "    preds_glasses = hcsvae_trainer.predictive(*hcsvae_trainer._send_args_to_device(glasses_batch, hcsvae_trainer.device))\n",
    "    preds_ng = hcsvae_trainer.predictive(*hcsvae_trainer._send_args_to_device(ng_batch, hcsvae_trainer.device))\n",
    "\n",
    "z_s_glasses = preds_glasses['z'][0].cpu()\n",
    "w_s_glasses = preds_glasses['w'][0].cpu()\n",
    "recons_glasses = preds_glasses['rec'][0, 0].cpu()\n",
    "\n",
    "z_s_ng = preds_ng['z'][0].cpu()\n",
    "w_s_ng = preds_ng['w'][0].cpu()\n",
    "recons_ng = preds_ng['rec'][0, 0].cpu()\n",
    "\n",
    "z_s = torch.vstack((z_s_glasses, z_s_ng))\n",
    "w_s = torch.vstack((w_s_glasses, w_s_ng))\n",
    "recons = torch.vstack((recons_glasses, recons_ng))\n",
    "y_s = torch.vstack((glasses_batch[1], ng_batch[1]))\n",
    "origs = torch.vstack((glasses_batch[0], ng_batch[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fe9ee53-76f3-4824-9717-e6b7fda6ad7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Originals: top row = test_set[0:10] (glasses block), bottom row = non-glasses images\n",
    "# test_set[glasses_test_count+10 : glasses_test_count+20], matching the next cell.\n",
    "fig, ax = plt.subplots(2, 10, figsize=(15, 3), gridspec_kw={'wspace': 0, 'hspace': -0.02})\n",
    "\n",
    "for i in range(20):\n",
    "    idx = i if i < 10 else glasses_test_count + i\n",
    "    im = test_set[idx][0]  # single load; the label is not needed for plotting\n",
    "    ax[i // 10][i % 10].imshow(np.einsum('ijk -> jki', im))\n",
    "    ax[i // 10][i % 10].axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d088d05-659a-4ea7-82fd-3664f9839672",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstructions in the same layout: recons stacks the 100 glasses reconstructions\n",
    "# first, then the 100 non-glasses ones, so index 100 + i pairs with\n",
    "# test_set[glasses_test_count + i].\n",
    "fig, ax = plt.subplots(2, 10, figsize=(15, 3), gridspec_kw={'wspace': 0, 'hspace': -0.02})\n",
    "\n",
    "for i in range(20):\n",
    "    rec_idx = i if i < 10 else 100 + i\n",
    "    im = torch.clamp(recons[rec_idx], 0, 1)  # outputs may leave [0, 1]; clip for imshow\n",
    "    ax[i // 10][i % 10].imshow(np.einsum('ijk -> jki', im))\n",
    "    ax[i // 10][i % 10].axis('off')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd228da9-de02-40f1-aa0e-93c56368f466",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## DIVA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e80c4d6-dac7-4cb6-a859-12b2c393bf0d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Re-seed so this model's initialisation and training are reproducible on their own\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "diva = CNNSDIVA(\n",
    "    (64, 64), 3, [1],\n",
    "    latent_dim=2048, w_dim=2, channels=[64, 128, 256], repeats=[2, 1, 1],\n",
    "    cnn_arch='conv+pool', num_layers=0,\n",
    "    recon_weight=1e5, kl_weight=1e-4, kernel_size=5,\n",
    ")\n",
    "diva_trainer = ThresholdPyroTrainer(\n",
    "    0, 50, diva, train_loader, test_loader,\n",
    "    opt.AdamW({\"lr\": 1e-4}),\n",
    ")\n",
    "diva_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc8a97ac-7ea5-4643-9732-3d98cb32d827",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on 100 glasses images (start of the test set) and 100 non-glasses images\n",
    "# (starting at glasses_test_count); assumes at least 100 of each. Each slice is read\n",
    "# from disk once and reused for the predictions, labels (y_s) and originals (origs).\n",
    "glasses_batch = diva_trainer.test_loader.dataset[:100]\n",
    "ng_batch = diva_trainer.test_loader.dataset[glasses_test_count:glasses_test_count+100]\n",
    "\n",
    "with torch.no_grad():  # inference only; no autograd graph needed\n",
    "    preds_glasses = diva_trainer.predictive(*diva_trainer._send_args_to_device(glasses_batch, diva_trainer.device))\n",
    "    preds_ng = diva_trainer.predictive(*diva_trainer._send_args_to_device(ng_batch, diva_trainer.device))\n",
    "\n",
    "z_s_glasses = preds_glasses['z'][0].cpu()\n",
    "w_s_glasses = preds_glasses['w'][0].cpu()\n",
    "recons_glasses = preds_glasses['rec'][0, 0].cpu()\n",
    "\n",
    "z_s_ng = preds_ng['z'][0].cpu()\n",
    "w_s_ng = preds_ng['w'][0].cpu()\n",
    "recons_ng = preds_ng['rec'][0, 0].cpu()\n",
    "\n",
    "z_s = torch.vstack((z_s_glasses, z_s_ng))\n",
    "w_s = torch.vstack((w_s_glasses, w_s_ng))\n",
    "recons = torch.vstack((recons_glasses, recons_ng))\n",
    "y_s = torch.vstack((glasses_batch[1], ng_batch[1]))\n",
    "origs = torch.vstack((glasses_batch[0], ng_batch[0]))  # same outputs as the other model sections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be01c9bf-7312-4b5b-94fb-3752fba468f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Originals: top row = test_set[0:10] (glasses block), bottom row = non-glasses images\n",
    "# test_set[glasses_test_count+10 : glasses_test_count+20], matching the next cell.\n",
    "fig, ax = plt.subplots(2, 10, figsize=(15, 3), gridspec_kw={'wspace': 0, 'hspace': -0.02})\n",
    "\n",
    "for i in range(20):\n",
    "    idx = i if i < 10 else glasses_test_count + i\n",
    "    im = test_set[idx][0]  # single load; the label is not needed for plotting\n",
    "    ax[i // 10][i % 10].imshow(np.einsum('ijk -> jki', im))\n",
    "    ax[i // 10][i % 10].axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44c016f7-7585-46ec-b2fc-694fcc452126",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstructions in the same layout: recons stacks the 100 glasses reconstructions\n",
    "# first, then the 100 non-glasses ones, so index 100 + i pairs with\n",
    "# test_set[glasses_test_count + i].\n",
    "fig, ax = plt.subplots(2, 10, figsize=(15, 3), gridspec_kw={'wspace': 0, 'hspace': -0.02})\n",
    "\n",
    "for i in range(20):\n",
    "    rec_idx = i if i < 10 else 100 + i\n",
    "    im = torch.clamp(recons[rec_idx], 0, 1)  # outputs may leave [0, 1]; clip for imshow\n",
    "    ax[i // 10][i % 10].imshow(np.einsum('ijk -> jki', im))\n",
    "    ax[i // 10][i % 10].axis('off')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1dbc9ea0-76b8-4431-b8a7-94b156577f55",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## CCVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67fe32d6-73fc-49e7-b1ab-68f2333324ff",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Re-seed so this model's initialisation and training are reproducible on their own\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "ccvae = CNNCCVAE(\n",
    "    (64, 64), 3, [1],\n",
    "    latent_dim=2048, w_dim=2, channels=[64, 128, 256], repeats=[2, 1, 1],\n",
    "    cnn_arch='conv+pool', num_layers=0,\n",
    "    recon_weight=1e5, kl_weight=1e-4, kernel_size=5,\n",
    ")\n",
    "ccvae_trainer = ThresholdPyroTrainer(\n",
    "    0, 50, ccvae, train_loader, test_loader,\n",
    "    opt.AdamW({\"lr\": 1e-4}),\n",
    ")\n",
    "ccvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4c07244-c7e4-45db-8ddb-aee15a24caed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on 100 glasses images (start of the test set) and 100 non-glasses images\n",
    "# (starting at glasses_test_count); assumes at least 100 of each. Each slice is read\n",
    "# from disk once and reused for the predictions, labels (y_s) and originals (origs).\n",
    "glasses_batch = ccvae_trainer.test_loader.dataset[:100]\n",
    "ng_batch = ccvae_trainer.test_loader.dataset[glasses_test_count:glasses_test_count+100]\n",
    "\n",
    "with torch.no_grad():  # inference only; no autograd graph needed\n",
    "    preds_glasses = ccvae_trainer.predictive(*ccvae_trainer._send_args_to_device(glasses_batch, ccvae_trainer.device))\n",
    "    preds_ng = ccvae_trainer.predictive(*ccvae_trainer._send_args_to_device(ng_batch, ccvae_trainer.device))\n",
    "\n",
    "z_s_glasses = preds_glasses['z'][0].cpu()\n",
    "w_s_glasses = preds_glasses['w'][0].cpu()\n",
    "recons_glasses = preds_glasses['rec'][0, 0].cpu()\n",
    "\n",
    "z_s_ng = preds_ng['z'][0].cpu()\n",
    "w_s_ng = preds_ng['w'][0].cpu()\n",
    "recons_ng = preds_ng['rec'][0, 0].cpu()\n",
    "\n",
    "z_s = torch.vstack((z_s_glasses, z_s_ng))\n",
    "w_s = torch.vstack((w_s_glasses, w_s_ng))\n",
    "recons = torch.vstack((recons_glasses, recons_ng))\n",
    "y_s = torch.vstack((glasses_batch[1], ng_batch[1]))\n",
    "origs = torch.vstack((glasses_batch[0], ng_batch[0]))  # same outputs as the other model sections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb1d58ea-67f5-4b1a-b2b6-8943b67b6c48",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "\n",
    "for i in range(10):\n",
    "    im, label = test_set[i][0], test_set[i][1]\n",
    "    im = np.einsum('ijk -> jki', im)\n",
    "\n",
    "    ax[i//10][i%10].imshow(im)\n",
    "    ax[i//10][i%10].axis('off')\n",
    "\n",
    "\n",
    "for i in range(10, 20):\n",
    "    im, label = test_set[glasses_test_count + i][0], test_set[glasses_test_count + i][1]\n",
    "    im = np.einsum('ijk -> jki', im)\n",
    "\n",
    "\n",
    "    ax[i//10][i%10].imshow(im)\n",
    "    ax[i//10][i%10].axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53d848ce-9e77-4708-a989-82711d6edaf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "\n",
    "for i in range(10):\n",
    "    im, label = recons[i], test_set[i][1]\n",
    "    im = torch.clamp(im, 0, 1)\n",
    "    im = np.einsum('ijk -> jki', im)\n",
    "\n",
    "    ax[i//10][i%10].imshow(im)\n",
    "    ax[i//10][i%10].axis('off')\n",
    "\n",
    "\n",
    "for i in range(10, 20):\n",
    "    im, label = recons[100 + i], test_set[glasses_test_count + i][1]\n",
    "    im = torch.clamp(im, 0, 1)\n",
    "    im = np.einsum('ijk -> jki', im)\n",
    "\n",
    "\n",
    "    ax[i//10][i%10].imshow(im)\n",
    "    ax[i//10][i%10].axis('off')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c35f916-d610-4e71-819e-6e559c5b520e",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## DISCoVeR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e61a3f83-16ed-4e5e-a130-f9f12ecd3e75",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "dlvae = CNNDLVAE((64,64), 3, [1], latent_dim=2048, w_dim=2048, channels=[64,128,256], repeats=[2,1,1], cnn_arch='conv+pool', num_layers=0, kernel_size=5, recon_weight=1e6, recon_weight_z=1e5, w_kl_weight=1e-4, z_kl_weight=1e-4, adversarial_weight=2e3, learnable_prior=False)\n",
    "dlvae_trainer = AdversarialThresholdPyroTrainer(0, 50, 1, 2, dlvae, train_loader, test_loader, opt.AdamW({\"lr\": 1e-4}))\n",
    "dlvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cc59fc1-8b4f-489e-b8f3-aa106c10a170",
   "metadata": {},
   "outputs": [],
   "source": [
    "window=0\n",
    "samp=100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46d00cce-f30a-47ed-9b36-4e16835f85de",
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_glasses = dlvae_trainer.predictive(*dlvae_trainer._send_args_to_device(dlvae_trainer.test_loader.dataset[window:samp+window], dlvae_trainer.device))\n",
    "z_s_glasses = preds_glasses['z'][0].cpu()\n",
    "w_s_glasses = preds_glasses['w'][0].cpu()\n",
    "recons_w_glasses = preds_glasses['rec_w'][0, 0].cpu()\n",
    "recons_z_glasses = preds_glasses['rec_z'][0, 0].cpu()\n",
    "\n",
    "\n",
    "preds_ng = dlvae_trainer.predictive(*dlvae_trainer._send_args_to_device(dlvae_trainer.test_loader.dataset[glasses_test_count+window:glasses_test_count+window+samp], dlvae_trainer.device))\n",
    "z_s_ng = preds_ng['z'][0].cpu() \n",
    "w_s_ng = preds_ng['w'][0].cpu() \n",
    "recons_w_ng = preds_ng['rec_w'][0, 0].cpu()\n",
    "recons_z_ng = preds_ng['rec_z'][0, 0].cpu()\n",
    "\n",
    "\n",
    "z_s = torch.vstack((z_s_glasses, z_s_ng))\n",
    "w_s = torch.vstack((w_s_glasses, w_s_ng))\n",
    "recons_w = torch.vstack((recons_w_glasses, recons_w_ng))\n",
    "recons_z = torch.vstack((recons_z_glasses, recons_z_ng))\n",
    "y_s = torch.vstack((dlvae_trainer.test_loader.dataset[window:window+samp][1], dlvae_trainer.test_loader.dataset[glasses_test_count+window:glasses_test_count+window+samp][1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb8952b3-1af1-4000-876d-4e91da133761",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "\n",
    "for i in range(10):\n",
    "    im, label = test_set[i+window][0], test_set[i+window][1]\n",
    "    im = np.einsum('ijk -> jki', im)\n",
    "\n",
    "    ax[i//10][i%10].imshow(im)\n",
    "    ax[i//10][i%10].axis('off')\n",
    "\n",
    "\n",
    "for i in range(10, 20):\n",
    "    im, label = test_set[glasses_test_count+window + i][0], test_set[glasses_test_count+window + i][1]\n",
    "    im = np.einsum('ijk -> jki', im)\n",
    "\n",
    "\n",
    "    ax[i//10][i%10].imshow(im)\n",
    "    ax[i//10][i%10].axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3617b680-9176-4d8d-98de-89839fe3e410",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "\n",
    "for i in range(10):\n",
    "    im, label = recons_w[i], test_set[i][1]\n",
    "    im = torch.clamp(im, 0, 1)\n",
    "    im = np.einsum('ijk -> jki', im)\n",
    "\n",
    "    ax[i//10][i%10].imshow(im)\n",
    "    ax[i//10][i%10].axis('off')\n",
    "\n",
    "\n",
    "for i in range(10, 20):\n",
    "    im, label = recons_w[100 + i], test_set[glasses_test_count + i][1]\n",
    "    im = torch.clamp(im, 0, 1)\n",
    "    im = np.einsum('ijk -> jki', im)\n",
    "\n",
    "\n",
    "    ax[i//10][i%10].imshow(im)\n",
    "    ax[i//10][i%10].axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45d92e46-4244-4c79-90ba-c129647b2d9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "\n",
    "for i in range(10):\n",
    "    im, label = recons_z[i], test_set[i][1]\n",
    "    im = torch.clamp(im, 0, 1)\n",
    "    im = np.einsum('ijk -> jki', im)\n",
    "\n",
    "    ax[i//10][i%10].imshow(im)\n",
    "    ax[i//10][i%10].axis('off')\n",
    "\n",
    "\n",
    "for i in range(10, 20):\n",
    "    im, label = recons_z[100 + i], test_set[glasses_test_count + i][1]\n",
    "    im = torch.clamp(im, 0, 1)\n",
    "    im = np.einsum('ijk -> jki', im)\n",
    "\n",
    "\n",
    "    ax[i//10][i%10].imshow(im)\n",
    "    ax[i//10][i%10].axis('off')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
