{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7b2987d-69eb-4f4f-a1c6-84f40f33906f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "sys.path.append('./src/')\n",
    "\n",
    "from VAE_trainers import EpochPyroTrainer, AdversarialEpochPyroTrainer, ThresholdPyroTrainer, AdversarialThresholdPyroTrainer\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "from tqdm import tqdm, trange\n",
    "from umap import UMAP\n",
    "from CNN_variants import CNNDecoder, CNNEncoder, GaussianCNNEncoder, CNNVAE, CNNCVAE, CNNCSVAENA, CNNCSVAE, CNNHCSVAENA, CNNHCSVAE, CNNSDIVA, CNNCCVAE, CNNDLVAE\n",
    "import pyro.distributions as dist\n",
    "\n",
    "\n",
    "import torch, pyro\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import copy, cv2\n",
    "import pyro.optim as opt\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision import transforms\n",
    "# Seed NumPy, PyTorch and Pyro so repeated runs draw the same random numbers.\n",
    "for seed_rng in (np.random.seed, torch.manual_seed, pyro.util.set_rng_seed):\n",
    "    seed_rng(42)\n",
    "\n",
    "# Two-colour gradient used to colour scatter points by label.\n",
    "hex_colors = ['#F23E2E', '#5888A6']\n",
    "cmap = LinearSegmentedColormap.from_list(name='cmap', colors=hex_colors)\n",
    "\n",
    "demo_epochs = 200  # epoch count handed to every trainer in this notebook\n",
    "noise = 0  # label-flip probability read by colorize_and_makedata"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "acc98276-fdfd-4f0c-8bfe-298e06599fea",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94164ee3-7876-46df-a6dd-25f7385a2822",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw MNIST splits (downloaded into ./data on first use), images converted with ToTensor.\n",
    "train_set = MNIST('./data', train=True, download=True, transform=transforms.ToTensor())\n",
    "test_set = MNIST('./data', train=False, download=True, transform=transforms.ToTensor())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0ee15f4-3bbf-4f15-9178-5bb923a15053",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.utils.data as utils\n",
    "\n",
    "def colorize_and_makedata(mnist_set, batch_size=128, test=False, noise_prob=None):\n",
    "    \"\"\"Turn a grayscale MNIST split into a two-colour dataset plus a DataLoader.\n",
    "\n",
    "    Every digit is converted to a 3-channel image and emitted twice: once with\n",
    "    channel 1 zeroed on the digit pixels (colour label 0) and once with channel 0\n",
    "    zeroed (colour label 1).\n",
    "\n",
    "    Args:\n",
    "        mnist_set: sized iterable of (image, digit) pairs; each image is a tensor\n",
    "            whose first axis is moved last before the gray-to-BGR conversion\n",
    "            (assumes the (1, H, W) output of ToTensor -- TODO confirm).\n",
    "        batch_size: batch size of the returned loader.\n",
    "        test: when true the loader keeps the dataset order, otherwise it shuffles.\n",
    "        noise_prob: probability of flipping each colour label. None (default)\n",
    "            falls back to the notebook-level `noise` setting.\n",
    "\n",
    "    Returns:\n",
    "        (dataset, loader). Each dataset item is (image, label, extras) where label\n",
    "        is the possibly flipped colour label and extras holds\n",
    "        [digit, unflipped colour label]. Only the first quarter of the generated\n",
    "        samples is kept.\n",
    "    \"\"\"\n",
    "    if noise_prob is None:\n",
    "        noise_prob = noise\n",
    "\n",
    "    ims, labels, unnoised_labels, numbers = [], [], [], []\n",
    "\n",
    "    for im_base, digit in tqdm(mnist_set, total=len(mnist_set)):\n",
    "        im_base = cv2.cvtColor(np.einsum('ijk -> jki', im_base.numpy()), cv2.COLOR_GRAY2BGR)\n",
    "        # The digit pixels are identical for both colour variants, so locate them once.\n",
    "        mask = np.any(im_base != 0, axis=-1)\n",
    "\n",
    "        for color_label, zeroed_channels in zip((0, 1), ([1], [0])):\n",
    "            im = im_base.copy()\n",
    "            masked_portion = im[mask]\n",
    "            masked_portion[:, zeroed_channels] = 0.\n",
    "            im[mask] = masked_portion\n",
    "            ims.append(np.einsum('jki -> ijk', im))\n",
    "\n",
    "            # One uniform draw per sample, even when noise_prob is 0, keeps the RNG stream unchanged.\n",
    "            flip = np.random.uniform() < noise_prob\n",
    "            labels.append(abs(color_label - 1) if flip else color_label)\n",
    "            unnoised_labels.append(color_label)\n",
    "            numbers.append(digit)\n",
    "\n",
    "    ims = torch.FloatTensor(np.array(ims))\n",
    "    labels = torch.FloatTensor(np.array(labels)).reshape(-1, 1)\n",
    "    unnoised_labels = torch.FloatTensor(np.array(unnoised_labels)).reshape(-1, 1)\n",
    "    numbers = torch.FloatTensor(np.array(numbers)).reshape(-1, 1)\n",
    "    extras = torch.hstack((numbers, unnoised_labels))\n",
    "\n",
    "    keep = len(ims) // 4  # Cut sizes even more for faster run in demo\n",
    "    dataset = utils.TensorDataset(ims[:keep], labels[:keep], extras[:keep])\n",
    "\n",
    "    # Shuffle for training only; evaluation keeps the dataset order.\n",
    "    loader = torch.utils.data.DataLoader(dataset, shuffle=not test, batch_size=batch_size)\n",
    "\n",
    "    return dataset, loader\n",
    "\n",
    "\n",
    "# NOTE: this rebinds train_set/test_set from the raw MNIST objects to TensorDatasets,\n",
    "# so re-running this cell on its own fails; re-run the MNIST loading cell first.\n",
    "train_set, train_loader = colorize_and_makedata(train_set)\n",
    "test_set, test_loader = colorize_and_makedata(test_set, test=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fe782ee-fa67-4687-90ff-8c291f681424",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "\n",
    "# Row 0 shows digits 0-9 (in order) taken from samples whose un-noised colour label is > 0,\n",
    "# row 1 shows digits 0-9 from samples whose label is < 1. The scan index j is shared,\n",
    "# so row 1 searches onward from where row 0 stopped.\n",
    "j = 0\n",
    "for row in range(2):\n",
    "    col = 0  # next column to fill, which is also the next digit wanted\n",
    "    while col < 10:\n",
    "        # Named `extras` rather than `_`: assigning to `_` stops IPython from\n",
    "        # updating its last-output variable.\n",
    "        im, label, extras = test_set[j]\n",
    "        digit, true_label = extras[0], extras[1]\n",
    "        colour_ok = true_label > 0 if row == 0 else true_label < 1\n",
    "\n",
    "        if colour_ok and digit == col:\n",
    "            im = np.einsum('ijk -> jki', torch.clamp(im, 0, 1))\n",
    "            ax[row][col].imshow(im.reshape(28,28,3))\n",
    "            ax[row][col].axis('off')\n",
    "            col += 1\n",
    "\n",
    "        j += 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c74ee148-8d2e-4189-a8c0-da0afd94352d",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "\n",
    "j = 0\n",
    "# Per row: (row index, channel that is halved, channel that receives the halved copy).\n",
    "# Row 0 uses samples with un-noised label > 0, row 1 samples with label < 1.\n",
    "for row, halved, target in ((0, 1, 0), (1, 0, 1)):\n",
    "    to_draw = list(range(10))\n",
    "\n",
    "    while to_draw:\n",
    "        im, label, extras = test_set[j]\n",
    "        true_label = extras[1]\n",
    "        colour_ok = true_label > 0 if row == 0 else true_label < 1\n",
    "\n",
    "        if colour_ok and extras[0] == to_draw[0]:\n",
    "            im_common = np.einsum('ijk -> jki', torch.clamp(im, 0, 1)).copy()\n",
    "            im_common[..., halved] = im_common[..., halved] / 2\n",
    "            im_common[..., target] = im_common[..., halved]\n",
    "\n",
    "            col = 10 - len(to_draw)\n",
    "            ax[row][col].imshow(im_common.reshape(28,28,3))\n",
    "            ax[row][col].axis('off')\n",
    "            to_draw.pop(0)\n",
    "\n",
    "        j += 1\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c892db50-d8d7-4f77-8cd7-cc535d08a00f",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## CSVAE - No Adv."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2e5bca9-dc91-484a-94a3-163cbabdc631",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Reset every RNG before building and training this model.\n",
    "for reseed in (np.random.seed, torch.manual_seed, pyro.util.set_rng_seed):\n",
    "    reseed(42)\n",
    "\n",
    "csvaena = CNNCSVAENA(\n",
    "    (28,28), 3, [1],\n",
    "    latent_dim=20, w_dim=2,\n",
    "    channels=[32,64,128], repeats=[2,1,1], cnn_arch='conv+pool',\n",
    "    recon_weight=1, z_kl_weight=1e-4, w_kl_weight=1,\n",
    ")\n",
    "csvaena_optimizer = opt.AdamW({'lr': 1e-4})\n",
    "csvaena_trainer = EpochPyroTrainer(demo_epochs, csvaena, train_loader, test_loader, csvaena_optimizer)\n",
    "csvaena_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d481c810-a95c-4aae-a1b8-4e8acb29f6dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the trainer's predictive on the first 1000 test samples.\n",
    "eval_batch = csvaena_trainer.test_loader.dataset[:1000]\n",
    "eval_args = csvaena_trainer._send_args_to_device(eval_batch, csvaena_trainer.device)\n",
    "preds = csvaena_trainer.predictive(*eval_args)\n",
    "\n",
    "# Keep index 0 along the leading axis (presumably the predictive-sample axis -- TODO confirm) on the CPU.\n",
    "z_s, w_s = (preds[site][0].cpu() for site in ('z', 'w'))\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90f714d8-2ea1-476e-89f3-d0a683082202",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the decoder output over random w draws: for every sample a fair coin picks\n",
    "# w ~ N(0, 0.1) or w ~ N(3, 1), and the decoded images are averaged over n_draws repeats.\n",
    "n_draws = 1000\n",
    "n_samples = z_s.shape[0]\n",
    "device = csvaena_trainer.device  # replaces the hard-coded .cuda(); assumes best_model lives there -- TODO confirm\n",
    "\n",
    "recon_sum = None\n",
    "with torch.no_grad():  # inference only, so skip building autograd graphs\n",
    "    for _draw in range(n_draws):\n",
    "        w_rand = torch.stack([\n",
    "            torch.normal(0, 0.1, (w_s.shape[-1],)) if np.random.choice([0,1], 1) else torch.normal(3, 1, (w_s.shape[-1],))\n",
    "            for _sample in range(n_samples)\n",
    "        ])\n",
    "        decoded = csvaena_trainer.best_model.decoder(torch.concatenate((z_s, w_rand), dim=-1).to(device)).detach().cpu()\n",
    "        # Running sum instead of stacking every draw, which would hold n_draws full image batches in memory.\n",
    "        recon_sum = decoded if recon_sum is None else recon_sum + decoded\n",
    "recon_test = recon_sum / n_draws\n",
    "\n",
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "\n",
    "# Row 0 collects the first ten averaged images whose un-noised label is > 0;\n",
    "# row 1 then collects ten with label < 1, scanning on from where row 0 stopped.\n",
    "j = 0\n",
    "for row in range(2):\n",
    "    col = 0\n",
    "    while col < 10:\n",
    "        im, extras = recon_test[j], test_set[j][2]\n",
    "        keep = extras[1] > 0 if row == 0 else extras[1] < 1\n",
    "        if keep:\n",
    "            im = np.einsum('ijk -> jki', torch.clamp(im, 0, 1))\n",
    "            ax[row][col].imshow(im.reshape(28,28,3))\n",
    "            ax[row][col].axis('off')\n",
    "            col += 1\n",
    "        j += 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "455cccfb-046c-4483-afcf-c0a0288d4da4",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "\n",
    "# Row 0 collects the first ten reconstructions whose un-noised label is > 0;\n",
    "# row 1 then collects ten with label < 1, scanning on from where row 0 stopped.\n",
    "j = 0\n",
    "for row in range(2):\n",
    "    col = 0\n",
    "    while col < 10:\n",
    "        im, extras = recons[j], test_set[j][2]\n",
    "        keep = extras[1] > 0 if row == 0 else extras[1] < 1\n",
    "        if keep:\n",
    "            im = np.einsum('ijk -> jki', torch.clamp(im, 0, 1))\n",
    "            ax[row][col].imshow(im.reshape(28,28,3))\n",
    "            ax[row][col].axis('off')\n",
    "            col += 1\n",
    "        j += 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61bb85ac-d15d-4b9f-ac85-c52474ccb869",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce z_s with a default-configured UMAP; the scatter cells plot columns 0 and 1 of the result.\n",
    "reducer = UMAP()\n",
    "umap_zs = reducer.fit_transform(X=z_s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39523cba-70da-400d-b9a4-c42d442a2c3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP embedding of z, coloured by the colour label of each sample.\n",
    "colour_labels = test_set[:1000][1].numpy()\n",
    "plt.scatter(x=umap_zs[:, 0], y=umap_zs[:, 1], c=colour_labels, cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d85f1387-b90f-417e-b5af-34a754220228",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP embedding of z, coloured by digit (column 0 of the extras tensor).\n",
    "digits = test_set[:1000][2][:, 0].numpy()\n",
    "plt.scatter(x=umap_zs[:, 0], y=umap_zs[:, 1], c=digits, cmap='tab10', alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "958ff6fd-722f-45cf-bd30-2229be396df7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two w dimensions, coloured by the colour label of each sample.\n",
    "colour_labels = test_set[:1000][1].numpy()\n",
    "plt.scatter(x=w_s[:, 0], y=w_s[:, 1], c=colour_labels, cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fd2848f-e71e-4166-a002-30c7ec6605ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two w dimensions, coloured by digit (column 0 of the extras tensor).\n",
    "digits = test_set[:1000][2][:, 0].numpy()\n",
    "plt.scatter(x=w_s[:, 0], y=w_s[:, 1], c=digits, cmap='tab10', alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f9c838a-8121-434c-88c6-b7770627ddb8",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## CSVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7dfea87-0ae0-40c0-b3b7-1265340ce548",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Reset every RNG before building and training this model.\n",
    "for reseed in (np.random.seed, torch.manual_seed, pyro.util.set_rng_seed):\n",
    "    reseed(42)\n",
    "\n",
    "csvae = CNNCSVAE(\n",
    "    (28,28), 3, [1],\n",
    "    latent_dim=20, w_dim=2,\n",
    "    channels=[32,64,128], repeats=[2,1,1], cnn_arch='conv+pool',\n",
    "    recon_weight=1, z_kl_weight=1e-4, w_kl_weight=1, adversarial_weight=1,\n",
    ")\n",
    "csvae_optimizer = opt.AdamW({'lr': 1e-4})\n",
    "csvae_trainer = AdversarialEpochPyroTrainer(demo_epochs, 1, 1, csvae, train_loader, test_loader, csvae_optimizer)\n",
    "csvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf2f395f-98b6-472e-8f0a-40a240930e4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the trainer's predictive on the first 1000 test samples.\n",
    "eval_batch = csvae_trainer.test_loader.dataset[:1000]\n",
    "eval_args = csvae_trainer._send_args_to_device(eval_batch, csvae_trainer.device)\n",
    "preds = csvae_trainer.predictive(*eval_args)\n",
    "\n",
    "# Keep index 0 along the leading axis (presumably the predictive-sample axis -- TODO confirm) on the CPU.\n",
    "z_s, w_s = (preds[site][0].cpu() for site in ('z', 'w'))\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2fc2c46-d12f-4596-ab1e-88af3bbbf0f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the decoder output over random w draws: for every sample a fair coin picks\n",
    "# w ~ N(0, 0.1) or w ~ N(3, 1), and the decoded images are averaged over n_draws repeats.\n",
    "n_draws = 1000\n",
    "n_samples = z_s.shape[0]\n",
    "device = csvae_trainer.device  # replaces the hard-coded .cuda(); assumes best_model lives there -- TODO confirm\n",
    "\n",
    "recon_sum = None\n",
    "with torch.no_grad():  # inference only, so skip building autograd graphs\n",
    "    for _draw in range(n_draws):\n",
    "        w_rand = torch.stack([\n",
    "            torch.normal(0, 0.1, (w_s.shape[-1],)) if np.random.choice([0,1], 1) else torch.normal(3, 1, (w_s.shape[-1],))\n",
    "            for _sample in range(n_samples)\n",
    "        ])\n",
    "        decoded = csvae_trainer.best_model.decoder(torch.concatenate((z_s, w_rand), dim=-1).to(device)).detach().cpu()\n",
    "        # Running sum instead of stacking every draw, which would hold n_draws full image batches in memory.\n",
    "        recon_sum = decoded if recon_sum is None else recon_sum + decoded\n",
    "recon_test = recon_sum / n_draws\n",
    "\n",
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "\n",
    "# Row 0 collects the first ten averaged images whose un-noised label is > 0;\n",
    "# row 1 then collects ten with label < 1, scanning on from where row 0 stopped.\n",
    "j = 0\n",
    "for row in range(2):\n",
    "    col = 0\n",
    "    while col < 10:\n",
    "        im, extras = recon_test[j], test_set[j][2]\n",
    "        keep = extras[1] > 0 if row == 0 else extras[1] < 1\n",
    "        if keep:\n",
    "            im = np.einsum('ijk -> jki', torch.clamp(im, 0, 1))\n",
    "            ax[row][col].imshow(im.reshape(28,28,3))\n",
    "            ax[row][col].axis('off')\n",
    "            col += 1\n",
    "        j += 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a0b1fc6-898b-464c-a932-cdb1750dc517",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "\n",
    "# Row 0 collects the first ten reconstructions whose un-noised label is > 0;\n",
    "# row 1 then collects ten with label < 1, scanning on from where row 0 stopped.\n",
    "j = 0\n",
    "for row in range(2):\n",
    "    col = 0\n",
    "    while col < 10:\n",
    "        im, extras = recons[j], test_set[j][2]\n",
    "        keep = extras[1] > 0 if row == 0 else extras[1] < 1\n",
    "        if keep:\n",
    "            im = np.einsum('ijk -> jki', torch.clamp(im, 0, 1))\n",
    "            ax[row][col].imshow(im.reshape(28,28,3))\n",
    "            ax[row][col].axis('off')\n",
    "            col += 1\n",
    "        j += 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c99736cc-50e5-443c-9581-56f76caff260",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce z_s with a default-configured UMAP; the scatter cells plot columns 0 and 1 of the result.\n",
    "reducer = UMAP()\n",
    "umap_zs = reducer.fit_transform(X=z_s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6bf2bf1-7bbf-4190-ba92-2a7309c7f2ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP embedding of z, coloured by the colour label of each sample.\n",
    "colour_labels = test_set[:1000][1].numpy()\n",
    "plt.scatter(x=umap_zs[:, 0], y=umap_zs[:, 1], c=colour_labels, cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4027d536-0af9-42f0-bc5b-29fe46cb37c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP embedding of z, coloured by digit (column 0 of the extras tensor).\n",
    "digits = test_set[:1000][2][:, 0].numpy()\n",
    "plt.scatter(x=umap_zs[:, 0], y=umap_zs[:, 1], c=digits, cmap='tab10', alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71ac71c2-3069-4c90-bd95-730d6d462e10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two w dimensions, coloured by the colour label of each sample.\n",
    "colour_labels = test_set[:1000][1].numpy()\n",
    "plt.scatter(x=w_s[:, 0], y=w_s[:, 1], c=colour_labels, cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d32e824f-62be-40a8-9da6-7741868b04a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two w dimensions, coloured by digit (column 0 of the extras tensor).\n",
    "digits = test_set[:1000][2][:, 0].numpy()\n",
    "plt.scatter(x=w_s[:, 0], y=w_s[:, 1], c=digits, cmap='tab10', alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ee05282-04d4-421b-a700-d2a6d59eadaa",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## HCSVAE - No Adv."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a078b6c-6d93-4378-befe-8a49e56fb1e0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Reset every RNG before building and training this model.\n",
    "for reseed in (np.random.seed, torch.manual_seed, pyro.util.set_rng_seed):\n",
    "    reseed(42)\n",
    "\n",
    "hcsvaena = CNNHCSVAENA(\n",
    "    (28,28), 3, [1],\n",
    "    latent_dim=20, w_dim=2,\n",
    "    channels=[32,64,128], repeats=[2,1,1], cnn_arch='conv+pool',\n",
    "    recon_weight=1e3, z_kl_weight=1e-4, w_kl_weight=1,\n",
    ")\n",
    "hcsvaena_optimizer = opt.AdamW({'lr': 1e-4})\n",
    "hcsvaena_trainer = EpochPyroTrainer(demo_epochs, hcsvaena, train_loader, test_loader, hcsvaena_optimizer)\n",
    "hcsvaena_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfc36194-6a4f-440e-917c-044afe3e4b6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the trainer's predictive on the first 1000 test samples.\n",
    "eval_batch = hcsvaena_trainer.test_loader.dataset[:1000]\n",
    "eval_args = hcsvaena_trainer._send_args_to_device(eval_batch, hcsvaena_trainer.device)\n",
    "preds = hcsvaena_trainer.predictive(*eval_args)\n",
    "\n",
    "# Keep index 0 along the leading axis (presumably the predictive-sample axis -- TODO confirm) on the CPU.\n",
    "z_s, w_s = (preds[site][0].cpu() for site in ('z', 'w'))\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "578c3035-d95f-4abf-9cc9-7d82291f7f50",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the decoder output over random w draws: for every sample a fair coin picks\n",
    "# w ~ N(0, 0.1) or w ~ N(3, 1); [z, w] goes through decoder_rho (first output only) and then\n",
    "# the decoder, and the decoded images are averaged over n_draws repeats.\n",
    "n_draws = 1000\n",
    "n_samples = z_s.shape[0]\n",
    "device = hcsvaena_trainer.device  # replaces the hard-coded .cuda(); assumes best_model lives there -- TODO confirm\n",
    "\n",
    "recon_sum = None\n",
    "with torch.no_grad():  # inference only, so skip building autograd graphs\n",
    "    for _draw in range(n_draws):\n",
    "        w_rand = torch.stack([\n",
    "            torch.normal(0, 0.1, (w_s.shape[-1],)) if np.random.choice([0,1], 1) else torch.normal(3, 1, (w_s.shape[-1],))\n",
    "            for _sample in range(n_samples)\n",
    "        ])\n",
    "        rho = hcsvaena_trainer.best_model.decoder_rho(torch.concatenate((z_s, w_rand), dim=-1).to(device))[0]\n",
    "        decoded = hcsvaena_trainer.best_model.decoder(rho).detach().cpu()\n",
    "        # Running sum instead of stacking every draw, which would hold n_draws full image batches in memory.\n",
    "        recon_sum = decoded if recon_sum is None else recon_sum + decoded\n",
    "recon_test = recon_sum / n_draws\n",
    "\n",
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "\n",
    "# Row 0 collects the first ten averaged images whose un-noised label is > 0;\n",
    "# row 1 then collects ten with label < 1, scanning on from where row 0 stopped.\n",
    "j = 0\n",
    "for row in range(2):\n",
    "    col = 0\n",
    "    while col < 10:\n",
    "        im, extras = recon_test[j], test_set[j][2]\n",
    "        keep = extras[1] > 0 if row == 0 else extras[1] < 1\n",
    "        if keep:\n",
    "            im = np.einsum('ijk -> jki', torch.clamp(im, 0, 1))\n",
    "            ax[row][col].imshow(im.reshape(28,28,3))\n",
    "            ax[row][col].axis('off')\n",
    "            col += 1\n",
    "        j += 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21adc685-f10f-4ac9-ab88-bf14a112947d",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "\n",
    "# Row 0 collects the first ten reconstructions whose un-noised label is > 0;\n",
    "# row 1 then collects ten with label < 1, scanning on from where row 0 stopped.\n",
    "j = 0\n",
    "for row in range(2):\n",
    "    col = 0\n",
    "    while col < 10:\n",
    "        im, extras = recons[j], test_set[j][2]\n",
    "        keep = extras[1] > 0 if row == 0 else extras[1] < 1\n",
    "        if keep:\n",
    "            im = np.einsum('ijk -> jki', torch.clamp(im, 0, 1))\n",
    "            ax[row][col].imshow(im.reshape(28,28,3))\n",
    "            ax[row][col].axis('off')\n",
    "            col += 1\n",
    "        j += 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24468d7c-9097-4653-ab9d-40a3fad95934",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce z_s with a default-configured UMAP; the scatter cells plot columns 0 and 1 of the result.\n",
    "reducer = UMAP()\n",
    "umap_zs = reducer.fit_transform(X=z_s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40e1eec1-65f3-4085-87aa-06dcd979ffab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP embedding of z, coloured by the colour label of each sample.\n",
    "colour_labels = test_set[:1000][1].numpy()\n",
    "plt.scatter(x=umap_zs[:, 0], y=umap_zs[:, 1], c=colour_labels, cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e76dc0b-80c9-4616-a0df-4dd3dc0faebc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP embedding of z, coloured by digit (column 0 of the extras tensor).\n",
    "digits = test_set[:1000][2][:, 0].numpy()\n",
    "plt.scatter(x=umap_zs[:, 0], y=umap_zs[:, 1], c=digits, cmap='tab10', alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44baaef8-20ba-48cc-8c69-1e542066c752",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two w dimensions, coloured by the colour label of each sample.\n",
    "colour_labels = test_set[:1000][1].numpy()\n",
    "plt.scatter(x=w_s[:, 0], y=w_s[:, 1], c=colour_labels, cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccc1a6d0-ed3e-438c-a0e7-9e8338d2b71a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two w dimensions, coloured by digit (column 0 of the extras tensor).\n",
    "digits = test_set[:1000][2][:, 0].numpy()\n",
    "plt.scatter(x=w_s[:, 0], y=w_s[:, 1], c=digits, cmap='tab10', alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aaab4f88-67d6-4609-be96-756ee57477d1",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## HCSVAE "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "299548b6-b99e-4480-b11d-13194ab7c22c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Reset every RNG before building and training this model.\n",
    "for reseed in (np.random.seed, torch.manual_seed, pyro.util.set_rng_seed):\n",
    "    reseed(42)\n",
    "\n",
    "hcsvae = CNNHCSVAE(\n",
    "    (28,28), 3, [1],\n",
    "    latent_dim=20, w_dim=2,\n",
    "    channels=[32,64,128], repeats=[2,1,1], cnn_arch='conv+pool',\n",
    "    recon_weight=1e4, z_kl_weight=1e-4, w_kl_weight=1, adversarial_weight=1,\n",
    ")\n",
    "hcsvae_optimizer = opt.AdamW({'lr': 1e-4})\n",
    "hcsvae_trainer = AdversarialEpochPyroTrainer(demo_epochs, 1, 1, hcsvae, train_loader, test_loader, hcsvae_optimizer)\n",
    "hcsvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53804585-995c-4fa1-9d52-c1350ad4c5ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the trainer's predictive on the first 1000 test samples.\n",
    "eval_batch = hcsvae_trainer.test_loader.dataset[:1000]\n",
    "eval_args = hcsvae_trainer._send_args_to_device(eval_batch, hcsvae_trainer.device)\n",
    "preds = hcsvae_trainer.predictive(*eval_args)\n",
    "\n",
    "# Keep index 0 along the leading axis (presumably the predictive-sample axis -- TODO confirm) on the CPU.\n",
    "z_s, w_s, rho_s = (preds[site][0].cpu() for site in ('z', 'w', 'rho'))\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74e0eafc-611b-46ca-a26a-affba7452adf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the decoder output over random w draws: for every sample a fair coin picks\n",
    "# w ~ N(0, 0.1) or w ~ N(3, 1); [z, w] goes through decoder_rho (first output only) and then\n",
    "# the decoder, and the decoded images are averaged over n_draws repeats.\n",
    "n_draws = 1000\n",
    "n_samples = z_s.shape[0]\n",
    "device = hcsvae_trainer.device  # replaces the hard-coded .cuda(); assumes best_model lives there -- TODO confirm\n",
    "\n",
    "recon_sum = None\n",
    "with torch.no_grad():  # inference only, so skip building autograd graphs\n",
    "    for _draw in range(n_draws):\n",
    "        w_rand = torch.stack([\n",
    "            torch.normal(0, 0.1, (w_s.shape[-1],)) if np.random.choice([0,1], 1) else torch.normal(3, 1, (w_s.shape[-1],))\n",
    "            for _sample in range(n_samples)\n",
    "        ])\n",
    "        rho = hcsvae_trainer.best_model.decoder_rho(torch.concatenate((z_s, w_rand), dim=-1).to(device))[0]\n",
    "        decoded = hcsvae_trainer.best_model.decoder(rho).detach().cpu()\n",
    "        # Running sum instead of stacking every draw, which would hold n_draws full image batches in memory.\n",
    "        recon_sum = decoded if recon_sum is None else recon_sum + decoded\n",
    "recon_test = recon_sum / n_draws\n",
    "\n",
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "\n",
    "# Row 0 collects the first ten averaged images whose un-noised label is > 0;\n",
    "# row 1 then collects ten with label < 1, scanning on from where row 0 stopped.\n",
    "j = 0\n",
    "for row in range(2):\n",
    "    col = 0\n",
    "    while col < 10:\n",
    "        im, extras = recon_test[j], test_set[j][2]\n",
    "        keep = extras[1] > 0 if row == 0 else extras[1] < 1\n",
    "        if keep:\n",
    "            im = np.einsum('ijk -> jki', torch.clamp(im, 0, 1))\n",
    "            ax[row][col].imshow(im.reshape(28,28,3))\n",
    "            ax[row][col].axis('off')\n",
    "            col += 1\n",
    "        j += 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23dc6567-25a1-467c-90b1-cc077bd162bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "\n",
    "# Row 0 collects the first ten reconstructions whose un-noised label is > 0;\n",
    "# row 1 then collects ten with label < 1, scanning on from where row 0 stopped.\n",
    "j = 0\n",
    "for row in range(2):\n",
    "    col = 0\n",
    "    while col < 10:\n",
    "        im, extras = recons[j], test_set[j][2]\n",
    "        keep = extras[1] > 0 if row == 0 else extras[1] < 1\n",
    "        if keep:\n",
    "            im = np.einsum('ijk -> jki', torch.clamp(im, 0, 1))\n",
    "            ax[row][col].imshow(im.reshape(28,28,3))\n",
    "            ax[row][col].axis('off')\n",
    "            col += 1\n",
    "        j += 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30621cd1-c58c-4acd-b4d5-08e7a7d3ac0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce z_s with a default-configured UMAP; the scatter cells plot columns 0 and 1 of the result.\n",
    "reducer = UMAP()\n",
    "umap_zs = reducer.fit_transform(X=z_s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df08033c-c7c2-48a4-a402-1b1ad01af721",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP embedding of z, coloured by the colour label of each sample.\n",
    "colour_labels = test_set[:1000][1].numpy()\n",
    "plt.scatter(x=umap_zs[:, 0], y=umap_zs[:, 1], c=colour_labels, cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bf12fe8-c888-4f30-9131-83b9e29d5b15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP embedding of z, coloured by digit (column 0 of the extras tensor).\n",
    "digits = test_set[:1000][2][:, 0].numpy()\n",
    "plt.scatter(x=umap_zs[:, 0], y=umap_zs[:, 1], c=digits, cmap='tab10', alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee27e561-4b47-43a4-926d-8f1bc98134e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two w dimensions, coloured by the colour label of each sample.\n",
    "colour_labels = test_set[:1000][1].numpy()\n",
    "plt.scatter(x=w_s[:, 0], y=w_s[:, 1], c=colour_labels, cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e939e75-0b5c-42bc-9d1d-c8366bcd6ad1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two w dimensions, coloured by digit (column 0 of the extras tensor).\n",
    "digits = test_set[:1000][2][:, 0].numpy()\n",
    "plt.scatter(x=w_s[:, 0], y=w_s[:, 1], c=digits, cmap='tab10', alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4af3d87f-8322-4d24-aa41-81add8e2ac15",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## DIVA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee035f27-c8e5-40d1-a887-56254fb31768",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Reset every RNG before building and training this model.\n",
    "for reseed in (np.random.seed, torch.manual_seed, pyro.util.set_rng_seed):\n",
    "    reseed(42)\n",
    "\n",
    "diva = CNNSDIVA(\n",
    "    (28,28), 3, [1],\n",
    "    latent_dim=20, w_dim=2,\n",
    "    channels=[32,64,128], repeats=[2,1,1], cnn_arch='conv+pool',\n",
    "    recon_weight=1, kl_weight=1e-4,\n",
    ")\n",
    "diva_optimizer = opt.AdamW({'lr': 1e-4})\n",
    "diva_trainer = EpochPyroTrainer(demo_epochs, diva, train_loader, test_loader, diva_optimizer)\n",
    "diva_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f71325e-69bc-497c-9495-bb9e75060dbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the trainer's predictive on the first 1000 test samples.\n",
    "eval_batch = diva_trainer.test_loader.dataset[:1000]\n",
    "eval_args = diva_trainer._send_args_to_device(eval_batch, diva_trainer.device)\n",
    "preds = diva_trainer.predictive(*eval_args)\n",
    "\n",
    "# Keep index 0 along the leading axis (presumably the predictive-sample axis -- TODO confirm) on the CPU.\n",
    "z_s, w_s = (preds[site][0].cpu() for site in ('z', 'w'))\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d799a473-7792-47dd-918d-8f18d531fd1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the decoder output over random w draws: each repeat assigns every sample a random\n",
    "# 0/1 label, samples w from Normal(*prior_w(label)), decodes [z, w], and the decoded images\n",
    "# are averaged over n_draws repeats.\n",
    "n_draws = 1000\n",
    "device = diva_trainer.device  # replaces the hard-coded .cuda(); assumes best_model lives there -- TODO confirm\n",
    "\n",
    "recon_sum = None\n",
    "with torch.no_grad():  # inference only, so skip building autograd graphs\n",
    "    for _draw in trange(n_draws):\n",
    "        random_labels = torch.FloatTensor(np.random.choice([0,1], (z_s.shape[0], 1))).to(device)\n",
    "        w_rand = dist.Normal(*diva_trainer.best_model.prior_w(random_labels)).sample().detach().cpu()\n",
    "        decoded = diva_trainer.best_model.decoder(torch.concatenate((z_s, w_rand), dim=-1).to(device)).detach().cpu()\n",
    "        # Running sum instead of stacking every draw, which would hold n_draws full image batches in memory.\n",
    "        recon_sum = decoded if recon_sum is None else recon_sum + decoded\n",
    "recon_test = recon_sum / n_draws\n",
    "\n",
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "\n",
    "# Row 0 collects the first ten averaged images whose un-noised label is > 0;\n",
    "# row 1 then collects ten with label < 1, scanning on from where row 0 stopped.\n",
    "j = 0\n",
    "for row in range(2):\n",
    "    col = 0\n",
    "    while col < 10:\n",
    "        im, extras = recon_test[j], test_set[j][2]\n",
    "        keep = extras[1] > 0 if row == 0 else extras[1] < 1\n",
    "        if keep:\n",
    "            im = np.einsum('ijk -> jki', torch.clamp(im, 0, 1))\n",
    "            ax[row][col].imshow(im.reshape(28,28,3))\n",
    "            ax[row][col].axis('off')\n",
    "            col += 1\n",
    "        j += 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46affdc9-feb4-44c2-99b8-fdbb437c9f94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of the model's own reconstructions (recons) for 20 test samples.\n",
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "for panel in ax.ravel():\n",
    "    panel.axis('off')  # hide every frame up front so unfilled panels stay blank\n",
    "\n",
    "# Row 0 shows the first 10 samples with test_set[j][2][1] > 0; row 1 continues the scan\n",
    "# and shows the next 10 with test_set[j][2][1] < 1.\n",
    "row_filters = (lambda v: v > 0, lambda v: v < 1)\n",
    "j = 0\n",
    "for row, keep in enumerate(row_filters):\n",
    "    col = 0\n",
    "    # Bound j by the number of reconstructions: running out of matches leaves blank\n",
    "    # panels instead of raising IndexError.\n",
    "    while col < 10 and j < len(recons):\n",
    "        if keep(test_set[j][2][1]):\n",
    "            im = torch.clamp(recons[j], 0, 1)\n",
    "            # channels-first -> channels-last, which is what imshow expects\n",
    "            ax[row][col].imshow(im.permute(1, 2, 0).detach().numpy())\n",
    "            col += 1\n",
    "        j += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e2b40fa-ad87-43ce-b8d7-d6649f5f5f93",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project the z latents to 2-D for plotting. A fixed random_state makes the embedding\n",
    "# reproducible, matching the seed (42) used for NumPy, PyTorch and Pyro in this notebook.\n",
    "reducer = UMAP(random_state=42)\n",
    "umap_zs = reducer.fit_transform(z_s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf7bca3c-8e14-472c-b634-3b0314b3634b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP of z, coloured by test_set[:1000][1] with the two-colour map\n",
    "# (presumably the binary label -- confirm against the dataset definition).\n",
    "point_colors = test_set[:1000][1].numpy()\n",
    "umap_x, umap_y = umap_zs[:, 0], umap_zs[:, 1]\n",
    "plt.scatter(umap_x, umap_y, c=point_colors, cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8ed6621-4cb0-4f68-b2d6-5566e6e95e01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP of z, coloured by column 0 of test_set[:1000][2] with the tab10 palette\n",
    "# (presumably the digit class -- confirm against the dataset definition).\n",
    "point_colors = test_set[:1000][2][:,0].numpy()\n",
    "umap_x, umap_y = umap_zs[:, 0], umap_zs[:, 1]\n",
    "plt.scatter(umap_x, umap_y, c=point_colors, cmap='tab10', alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a117794-1fa4-4125-b779-06b706b2999f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two coordinates of w, coloured by test_set[:1000][1] with the two-colour map\n",
    "# (presumably the binary label -- confirm against the dataset definition).\n",
    "point_colors = test_set[:1000][1].numpy()\n",
    "w_x, w_y = w_s[:, 0], w_s[:, 1]\n",
    "plt.scatter(w_x, w_y, c=point_colors, cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e03a93ea-5970-4ac8-aaf3-578389f16883",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two coordinates of w, coloured by column 0 of test_set[:1000][2] with the tab10 palette\n",
    "# (presumably the digit class -- confirm against the dataset definition).\n",
    "point_colors = test_set[:1000][2][:,0].numpy()\n",
    "w_x, w_y = w_s[:, 0], w_s[:, 1]\n",
    "plt.scatter(w_x, w_y, c=point_colors, cmap='tab10', alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5aab137f-946b-4406-bce8-dc313b9ec62f",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## CCVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09f677a5-f958-490a-aa63-94baea5ae85f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Re-seed NumPy, PyTorch and Pyro so this section's training run does not depend on earlier cells.\n",
    "for seed_fn in (np.random.seed, torch.manual_seed, pyro.util.set_rng_seed):\n",
    "    seed_fn(42)\n",
    "\n",
    "# Build the CCVAE model; the leading positional arguments are defined by CNNCCVAE in CNN_variants.\n",
    "ccvae = CNNCCVAE(\n",
    "    (28,28), 3, [1],\n",
    "    latent_dim=20, w_dim=2,\n",
    "    channels=[32,64,128], repeats=[2,1,1], cnn_arch='conv+pool',\n",
    "    recon_weight=1, kl_weight=1e-4,\n",
    ")\n",
    "\n",
    "# Train for demo_epochs epochs with AdamW on the loaders created earlier in the notebook.\n",
    "ccvae_optimizer = opt.AdamW({\"lr\": 1e-4})\n",
    "ccvae_trainer = EpochPyroTrainer(demo_epochs, ccvae, train_loader, test_loader, ccvae_optimizer)\n",
    "ccvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6834b83b-3dc3-47b8-b119-e7c60dad7ffb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the trainer's predictive on the first 1000 test samples, moved to the trainer's device.\n",
    "ccvae_eval_batch = ccvae_trainer.test_loader.dataset[:1000]\n",
    "ccvae_eval_args = ccvae_trainer._send_args_to_device(ccvae_eval_batch, ccvae_trainer.device)\n",
    "preds = ccvae_trainer.predictive(*ccvae_eval_args)\n",
    "\n",
    "# Keep the first entry along the leading axis of each site (presumably the sample axis -- TODO confirm)\n",
    "# and bring everything back to the CPU for plotting.\n",
    "z_s, w_s = (preds[site][0].cpu() for site in ('z', 'w'))\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b9e6c4b-ecc5-45a1-ab24-bcae484d636b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode each sample from its inferred z plus a w drawn from the prior of a random binary label,\n",
    "# and average the decodings over many prior draws.\n",
    "# A running sum is used instead of stacking every draw: the stacked version has to keep all\n",
    "# n_prior_draws decoder outputs for the whole batch in memory before taking the mean.\n",
    "n_prior_draws = 1000\n",
    "ccvae_device = ccvae_trainer.device  # the device the predictive inputs were sent to (replaces hard-coded .cuda())\n",
    "ccvae_best = ccvae_trainer.best_model\n",
    "z_on_device = z_s.to(ccvae_device)\n",
    "\n",
    "recon_sum = None\n",
    "with torch.no_grad():  # inference only: no autograd graph is needed\n",
    "    for draw in trange(n_prior_draws):\n",
    "        rand_labels = torch.FloatTensor(np.random.choice([0,1], (z_s.shape[0], 1))).to(ccvae_device)\n",
    "        w_prior = dist.Normal(*ccvae_best.prior_w(rand_labels)).sample()\n",
    "        rec = ccvae_best.decoder(torch.cat((z_on_device, w_prior), dim=-1)).cpu()\n",
    "        recon_sum = rec if recon_sum is None else recon_sum + rec\n",
    "recon_test = recon_sum / n_prior_draws\n",
    "\n",
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "for panel in ax.ravel():\n",
    "    panel.axis('off')  # hide every frame up front so unfilled panels stay blank\n",
    "\n",
    "# Row 0 shows the first 10 samples with test_set[j][2][1] > 0; row 1 continues the scan\n",
    "# and shows the next 10 with test_set[j][2][1] < 1.\n",
    "row_filters = (lambda v: v > 0, lambda v: v < 1)\n",
    "j = 0\n",
    "for row, keep in enumerate(row_filters):\n",
    "    col = 0\n",
    "    # Bound j by the number of reconstructions: running out of matches leaves blank\n",
    "    # panels instead of raising IndexError.\n",
    "    while col < 10 and j < len(recon_test):\n",
    "        if keep(test_set[j][2][1]):\n",
    "            im = torch.clamp(recon_test[j], 0, 1)\n",
    "            # channels-first -> channels-last, which is what imshow expects\n",
    "            ax[row][col].imshow(im.permute(1, 2, 0).numpy())\n",
    "            col += 1\n",
    "        j += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dabff361-e84d-4e31-9232-6fe2438ecfb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of the model's own reconstructions (recons) for 20 test samples.\n",
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "for panel in ax.ravel():\n",
    "    panel.axis('off')  # hide every frame up front so unfilled panels stay blank\n",
    "\n",
    "# Row 0 shows the first 10 samples with test_set[j][2][1] > 0; row 1 continues the scan\n",
    "# and shows the next 10 with test_set[j][2][1] < 1.\n",
    "row_filters = (lambda v: v > 0, lambda v: v < 1)\n",
    "j = 0\n",
    "for row, keep in enumerate(row_filters):\n",
    "    col = 0\n",
    "    # Bound j by the number of reconstructions: running out of matches leaves blank\n",
    "    # panels instead of raising IndexError.\n",
    "    while col < 10 and j < len(recons):\n",
    "        if keep(test_set[j][2][1]):\n",
    "            im = torch.clamp(recons[j], 0, 1)\n",
    "            # channels-first -> channels-last, which is what imshow expects\n",
    "            ax[row][col].imshow(im.permute(1, 2, 0).detach().numpy())\n",
    "            col += 1\n",
    "        j += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55318474-393c-4e53-9059-c0272b060257",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project the z latents to 2-D for plotting. A fixed random_state makes the embedding\n",
    "# reproducible, matching the seed (42) used for NumPy, PyTorch and Pyro in this notebook.\n",
    "reducer = UMAP(random_state=42)\n",
    "umap_zs = reducer.fit_transform(z_s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a1b0b3c-bb19-44db-a7d8-ca153209aaf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP of z, coloured by test_set[:1000][1] with the two-colour map\n",
    "# (presumably the binary label -- confirm against the dataset definition).\n",
    "point_colors = test_set[:1000][1].numpy()\n",
    "umap_x, umap_y = umap_zs[:, 0], umap_zs[:, 1]\n",
    "plt.scatter(umap_x, umap_y, c=point_colors, cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "782c304f-8e75-4696-97d0-13aaa1a35736",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP of z, coloured by column 0 of test_set[:1000][2] with the tab10 palette\n",
    "# (presumably the digit class -- confirm against the dataset definition).\n",
    "point_colors = test_set[:1000][2][:,0].numpy()\n",
    "umap_x, umap_y = umap_zs[:, 0], umap_zs[:, 1]\n",
    "plt.scatter(umap_x, umap_y, c=point_colors, cmap='tab10', alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac726a8c-a603-4c20-b18e-72bc4260fb87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two coordinates of w, coloured by test_set[:1000][1] with the two-colour map\n",
    "# (presumably the binary label -- confirm against the dataset definition).\n",
    "point_colors = test_set[:1000][1].numpy()\n",
    "w_x, w_y = w_s[:, 0], w_s[:, 1]\n",
    "plt.scatter(w_x, w_y, c=point_colors, cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3767d64c-065b-4bd8-a1a9-dfa389337cf9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two coordinates of w, coloured by column 0 of test_set[:1000][2] with the tab10 palette\n",
    "# (presumably the digit class -- confirm against the dataset definition).\n",
    "point_colors = test_set[:1000][2][:,0].numpy()\n",
    "w_x, w_y = w_s[:, 0], w_s[:, 1]\n",
    "plt.scatter(w_x, w_y, c=point_colors, cmap='tab10', alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69e27321-ee65-48b0-9d2d-af397561f6c9",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Ours"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7736c16-d6aa-46e6-abd5-bb21f5f29651",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Re-seed NumPy, PyTorch and Pyro so this section's training run does not depend on earlier cells.\n",
    "for seed_fn in (np.random.seed, torch.manual_seed, pyro.util.set_rng_seed):\n",
    "    seed_fn(42)\n",
    "\n",
    "# Build the DLVAE model; the leading positional arguments are defined by CNNDLVAE in CNN_variants.\n",
    "dlvae = CNNDLVAE(\n",
    "    (28,28), 3, [1],\n",
    "    latent_dim=20, w_dim=2,\n",
    "    channels=[32,64,128], repeats=[2,1,1], cnn_arch='conv+pool',\n",
    "    recon_weight=5e-1, recon_weight_z=5e-1,\n",
    "    w_kl_weight=1e-4, z_kl_weight=1e-4,\n",
    "    adversarial_weight=1e-1, learnable_prior=True,\n",
    ")\n",
    "\n",
    "# The two positional 1s after demo_epochs are defined by AdversarialEpochPyroTrainer\n",
    "# (not visible in this notebook) -- check VAE_trainers before changing them.\n",
    "dlvae_optimizer = opt.AdamW({\"lr\": 1e-4})\n",
    "dlvae_trainer = AdversarialEpochPyroTrainer(demo_epochs, 1, 1, dlvae, train_loader, test_loader, dlvae_optimizer)\n",
    "dlvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fed1689-9293-4275-ade4-ade54609ed4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the trainer's predictive on the first 1000 test samples, moved to the trainer's device.\n",
    "dlvae_eval_batch = dlvae_trainer.test_loader.dataset[:1000]\n",
    "dlvae_eval_args = dlvae_trainer._send_args_to_device(dlvae_eval_batch, dlvae_trainer.device)\n",
    "preds = dlvae_trainer.predictive(*dlvae_eval_args)\n",
    "\n",
    "# Keep the first entry along the leading axis of each site (presumably the sample axis -- TODO confirm)\n",
    "# and bring everything back to the CPU for plotting.\n",
    "z_s, w_s = (preds[site][0].cpu() for site in ('z', 'w'))\n",
    "recons_z, recons_w = (preds[site][0, 0].cpu() for site in ('rec_z', 'rec_w'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c417d33-8cc5-49d6-a68f-7929dadafc73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of the 'rec_z' reconstructions (recons_z) for 20 test samples.\n",
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "for panel in ax.ravel():\n",
    "    panel.axis('off')  # hide every frame up front so unfilled panels stay blank\n",
    "\n",
    "# Row 0 shows the first 10 samples with test_set[j][2][1] > 0; row 1 continues the scan\n",
    "# and shows the next 10 with test_set[j][2][1] < 1.\n",
    "row_filters = (lambda v: v > 0, lambda v: v < 1)\n",
    "j = 0\n",
    "for row, keep in enumerate(row_filters):\n",
    "    col = 0\n",
    "    # Bound j by the number of reconstructions: running out of matches leaves blank\n",
    "    # panels instead of raising IndexError.\n",
    "    while col < 10 and j < len(recons_z):\n",
    "        if keep(test_set[j][2][1]):\n",
    "            im = torch.clamp(recons_z[j], 0, 1)\n",
    "            # channels-first -> channels-last, which is what imshow expects\n",
    "            ax[row][col].imshow(im.permute(1, 2, 0).detach().numpy())\n",
    "            col += 1\n",
    "        j += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a775b11b-5117-41a4-8de9-436f0be70d76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of the 'rec_w' reconstructions (recons_w), one column per value 0..9 of test_set[j][2][0].\n",
    "fig, ax = plt.subplots(2,10, figsize=(15,3), gridspec_kw = {'wspace':0, 'hspace':-0.02})\n",
    "for panel in ax.ravel():\n",
    "    panel.axis('off')  # hide every frame up front so unfilled panels stay blank\n",
    "\n",
    "# Row 0 uses samples with test_set[j][2][1] > 0, row 1 continues the scan with samples < 1.\n",
    "row_filters = (lambda v: v > 0, lambda v: v < 1)\n",
    "j = 0\n",
    "for row, keep in enumerate(row_filters):\n",
    "    # col is both the panel column and the num[0] value wanted next, so each row reads\n",
    "    # 0..9 from left to right (this replaces the separate to_draw list).\n",
    "    col = 0\n",
    "    # Bound j by the number of reconstructions: running out of matches leaves blank\n",
    "    # panels instead of raising IndexError.\n",
    "    while col < 10 and j < len(recons_w):\n",
    "        num = test_set[j][2]\n",
    "        if keep(num[1]) and num[0] == col:\n",
    "            im = torch.clamp(recons_w[j], 0, 1)\n",
    "            # channels-first -> channels-last, which is what imshow expects\n",
    "            ax[row][col].imshow(im.permute(1, 2, 0).detach().numpy())\n",
    "            col += 1\n",
    "        j += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8f9459e-9777-4d6d-88bd-d6061ed516bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project the z latents to 2-D for plotting. A fixed random_state makes the embedding\n",
    "# reproducible, matching the seed (42) used for NumPy, PyTorch and Pyro in this notebook.\n",
    "reducer = UMAP(random_state=42)\n",
    "umap_zs = reducer.fit_transform(z_s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42bb0e04-0035-4cb2-89c1-54f3b3e3c5b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP of z, coloured by test_set[:1000][1] with the two-colour map\n",
    "# (presumably the binary label -- confirm against the dataset definition).\n",
    "point_colors = test_set[:1000][1].numpy()\n",
    "umap_x, umap_y = umap_zs[:, 0], umap_zs[:, 1]\n",
    "plt.scatter(umap_x, umap_y, c=point_colors, cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20344618-5dde-42be-a6e0-0871010eb6ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two coordinates of w, coloured by test_set[:1000][1] with the two-colour map\n",
    "# (presumably the binary label -- confirm against the dataset definition).\n",
    "point_colors = test_set[:1000][1].numpy()\n",
    "w_x, w_y = w_s[:, 0], w_s[:, 1]\n",
    "plt.scatter(w_x, w_y, c=point_colors, cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7dbf0a7-175c-4bab-92a9-ffd724875b0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UMAP of z, coloured by column 0 of test_set[:1000][2] with the tab10 palette\n",
    "# (presumably the digit class -- confirm against the dataset definition).\n",
    "point_colors = test_set[:1000][2][:,0].numpy()\n",
    "umap_x, umap_y = umap_zs[:, 0], umap_zs[:, 1]\n",
    "plt.scatter(umap_x, umap_y, c=point_colors, cmap='tab10', alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17f69188-82d2-47f2-b0c3-ef1e0042f7e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two coordinates of w, coloured by column 0 of test_set[:1000][2] with the tab10 palette\n",
    "# (presumably the digit class -- confirm against the dataset definition).\n",
    "point_colors = test_set[:1000][2][:,0].numpy()\n",
    "w_x, w_y = w_s[:, 0], w_s[:, 1]\n",
    "plt.scatter(w_x, w_y, c=point_colors, cmap='tab10', alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
