{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd194bc5-0032-4c7a-957f-f9a622d90391",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "sys.path.append('./src/')\n",
    "\n",
    "from VAE_variants import VAE, CVAE, CSVAENA, CSVAE, HCSVAENA, HCSVAE, DLVAE, SDIVA, CCVAE\n",
    "from VAE_trainers import EpochPyroTrainer, AdversarialEpochPyroTrainer, ThresholdPyroTrainer, AdversarialThresholdPyroTrainer\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "from sklearn.naive_bayes import GaussianNB\n",
    "from metrics import MINE\n",
    "from tqdm import trange\n",
    "import pyro.distributions as dist\n",
    "\n",
    "\n",
    "import torch, pyro\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pyro.optim as opt\n",
    "import seaborn as sns\n",
    "\n",
    "# Seed NumPy, PyTorch and Pyro (in this order) with the same value.\n",
    "for seed_fn in (np.random.seed, torch.manual_seed, pyro.util.set_rng_seed):\n",
    "    seed_fn(42)\n",
    "\n",
    "# First argument passed to every trainer in this notebook (presumably the number of training epochs).\n",
    "demo_epochs = 5\n",
    "\n",
    "# Two-colour map used for the label-coloured scatter plots.\n",
    "class_colours = ['#DB6D00', '#006DDB']\n",
    "cmap = LinearSegmentedColormap.from_list('cmap', class_colours)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ef2d5df-f7f4-4b5d-8a75-57fd9fd8cfa1",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81c341ac-fb66-4999-ba38-263ea0fee4d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.utils.data as utils\n",
    "from sklearn.datasets import make_swiss_roll\n",
    "\n",
    "#################### DATASET PARAMS #########################################################################################\n",
    "n=10000  # Number of points for the roll\n",
    "batch_size=64 # Loader batch size\n",
    "p=0.3  # Probability of flipping each label\n",
    "##################################################################################################################################\n",
    "\n",
    "\n",
    "def _make_noisy_swiss_roll(n_points, flip_prob):\n",
    "    \"\"\"Sample a swiss roll and label it by a noisy threshold on the second coordinate.\n",
    "\n",
    "    Args:\n",
    "        n_points: Number of points requested from ``make_swiss_roll``.\n",
    "        flip_prob: Probability with which each binary label is flipped.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(points, noisy_labels, clean_labels)``: the shifted float\n",
    "        coordinates and two float label columns of shape ``(n_points, 1)``\n",
    "        holding the labels after and before flipping.\n",
    "    \"\"\"\n",
    "    points, _ = make_swiss_roll(n_points)\n",
    "    points = torch.FloatTensor(points)\n",
    "    # Label is 1 where the (unshifted) second coordinate is below 10.\n",
    "    clean = (points[:, 1] < 10).type(torch.long)\n",
    "\n",
    "    # One uniform draw per point from NumPy's global RNG, taken in point order,\n",
    "    # so the flips match a per-point loop of np.random.uniform() calls.\n",
    "    flip = torch.from_numpy(np.random.uniform(size=len(clean)) < flip_prob)\n",
    "    noisy = torch.where(flip, 1 - clean, clean)\n",
    "\n",
    "    points = (points - np.array([points[:,0].min() - 5, -5, points[:,2].min() - 5])).float() # Shift to make entries non-negative\n",
    "    return points, noisy.float().reshape(-1, 1), clean.float().reshape(-1, 1)\n",
    "\n",
    "\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "# Training split; `ys_real` keeps the labels as they were before flipping.\n",
    "xs, ys, ys_real = _make_noisy_swiss_roll(n, p)\n",
    "dataset = utils.TensorDataset(xs, ys, ys) # Duplicate ys, needed for workflow\n",
    "train_set = dataset\n",
    "train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=batch_size)\n",
    "\n",
    "# Test split: a second, independent draw that continues the same RNG stream.\n",
    "xs, ys, _ = _make_noisy_swiss_roll(n, p)\n",
    "dataset_test = utils.TensorDataset(xs, ys, ys) # Duplicate ys, needed for workflow\n",
    "test_set = dataset_test\n",
    "test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "388f0168-fe9b-42e3-aee9-7458a55f5323",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test inputs with the second coordinate replaced by its mean over the test set.\n",
    "test_inputs = test_set[:][0]\n",
    "test_set_marg = test_inputs.clone()\n",
    "test_set_marg[:, 1] = test_inputs[:, 1].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db43ed1c-9a97-4b71-b3f4-43edab56378d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3-D scatter of the training points, coloured by their labels.\n",
    "train_points = dataset[:][0]\n",
    "train_labels = train_set[:][1].numpy()\n",
    "\n",
    "fig = plt.figure(figsize=(15, 15))\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "scatter = ax.scatter(\n",
    "    train_points[:, 0],\n",
    "    train_points[:, 1],\n",
    "    train_points[:, 2],\n",
    "    c=train_labels,\n",
    "    cmap=cmap,\n",
    "    alpha=0.6,\n",
    ")\n",
    "\n",
    "# Disabled legend, kept for reference:\n",
    "#   elems = list(scatter.legend_elements())\n",
    "#   legend = ax.legend(*elems, loc='lower left', title='Classes',\n",
    "#                      title_fontsize=20, fontsize=18, markerscale=3)\n",
    "\n",
    "ax.axis('off')\n",
    "\n",
    "label_font = dict(weight='bold', fontsize=18)\n",
    "ax.set_xlabel('X', labelpad=30, **label_font)\n",
    "ax.set_ylabel('Y', labelpad=30, **label_font)\n",
    "ax.set_zlabel('Z', labelpad=7, **label_font)\n",
    "\n",
    "# Same [0, 30] range on all three axes.\n",
    "for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):\n",
    "    set_limits(0, 30)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8553767-b42c-438b-b417-2fe7fe5fc7a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3-D scatter of the training points with the second coordinate collapsed to its mean.\n",
    "train_points = dataset[:][0]\n",
    "train_labels = train_set[:][1].numpy()\n",
    "\n",
    "# Flat 1-D tensor with the same shape as the x and z columns, so the three\n",
    "# coordinate arrays line up element-wise instead of relying on a (n, 1) column\n",
    "# being flattened by Matplotlib.\n",
    "y_collapsed = torch.full((len(dataset),), train_points[:, 1].mean().item())\n",
    "\n",
    "fig = plt.figure(figsize=(15, 15))\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "scatter = ax.scatter(\n",
    "    train_points[:, 0],\n",
    "    y_collapsed,\n",
    "    train_points[:, 2],\n",
    "    c=train_labels,\n",
    "    cmap=cmap,\n",
    "    alpha=0.6,\n",
    ")\n",
    "\n",
    "# Disabled legend, kept for reference:\n",
    "#   elems = list(scatter.legend_elements())\n",
    "#   legend = ax.legend(*elems, loc='lower left', title='Classes',\n",
    "#                      title_fontsize=20, fontsize=18, markerscale=3)\n",
    "\n",
    "ax.axis('off')\n",
    "\n",
    "ax.set_xlabel('X', weight='bold', labelpad=30, fontsize=18)\n",
    "ax.set_ylabel('Y', weight='bold', labelpad=30, fontsize=18)\n",
    "ax.set_zlabel('Z', weight='bold', labelpad=7, fontsize=18)\n",
    "\n",
    "ax.set_xlim(0, 30)\n",
    "ax.set_ylim(0, 30)\n",
    "ax.set_zlim(0, 30)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53c79ccc-3e00-46c6-8e21-d02e42f46ccf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D view of the training points (first vs. third coordinate), coloured by label.\n",
    "train_points = dataset[:][0]\n",
    "plt.scatter(train_points[:, 0], train_points[:, 2], c=train_set[:][1].numpy(), cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bbda67d-0a8b-44fd-b6d4-f9f5fb091bfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Strip plot of the second coordinate of the training points.\n",
    "second_coord = dataset[:][0][:, 1]\n",
    "sns.stripplot(second_coord, c=train_set[:][1].numpy(), cmap=cmap, alpha=0.6, orient='y')\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6194cd49-7d66-4526-a49c-a60e7e8bc622",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: Gaussian naive Bayes on all three raw coordinates, fitted on train and scored on test.\n",
    "nb_all = GaussianNB().fit(train_set[:][0], train_set[:][1].squeeze(-1))\n",
    "orig_data_score = nb_all.score(test_set[:][0], test_set[:][1].squeeze(-1))\n",
    "print(f'Bayes classifier acc - Combined: {np.round(orig_data_score, 2)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc53e247-94d0-4a7d-ac07-4d89a6fcadaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline on the first and third coordinates only, with fixed class priors.\n",
    "common_cols = [0, 2]\n",
    "nb_common = GaussianNB(priors=(0.49, 0.51))  # sample more points if needed\n",
    "nb_common.fit(train_set[:][0][:, common_cols], train_set[:][1].squeeze(-1))\n",
    "orig_data_score_common = nb_common.score(test_set[:][0][:, common_cols], test_set[:][1].squeeze(-1))\n",
    "print(f'Bayes classifier acc - Common: {np.round(orig_data_score_common, 2)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90f2bad3-4d21-4a5d-9354-4f75fe6ff478",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline on the second coordinate only (reshaped to a single-feature column).\n",
    "train_second = train_set[:][0][:, 1].reshape(-1, 1)\n",
    "test_second = test_set[:][0][:, 1].reshape(-1, 1)\n",
    "nb_cond = GaussianNB().fit(train_second, train_set[:][1].squeeze(-1))\n",
    "orig_data_score_cond = nb_cond.score(test_second, test_set[:][1].squeeze(-1))\n",
    "print(f'Bayes classifier acc - Cond: {np.round(orig_data_score_cond, 2)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfd46140-0da1-4371-9932-171f60db9cf7",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## CSVAE - No Adv."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfdc196a-0365-4396-b716-665309ef6d17",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Reset the NumPy, PyTorch and Pyro seeds to 42 before building the model.\n",
    "for seed_fn in (np.random.seed, torch.manual_seed, pyro.util.set_rng_seed):\n",
    "    seed_fn(42)\n",
    "\n",
    "# Instantiate the model and its trainer, then run training.\n",
    "csvaena = CSVAENA(\n",
    "    3,\n",
    "    [1],\n",
    "    latent_dim=2,\n",
    "    w_dim=2,\n",
    "    num_layers=2,\n",
    "    recon_weight=20,\n",
    "    z_kl_weight=2e-1,\n",
    ")\n",
    "csvaena_trainer = EpochPyroTrainer(demo_epochs, csvaena, train_loader, test_loader)\n",
    "csvaena_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "185292b0-5133-4428-946e-164bb405bd5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the variables returned by get_variables('test') and move them to the CPU.\n",
    "preds = csvaena_trainer.get_variables('test')\n",
    "z_s, w_s = (preds[key][0].cpu() for key in ('z', 'w'))\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3ccc8e2-662a-41d5-ad57-6d0efb69fd4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate I(z ; w) with MINE. Run on the GPU when CUDA is available and\n",
    "# otherwise stay on the CPU instead of failing on hard-coded .cuda() calls.\n",
    "mine_device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "mine_estimator = MINE(4, [128])\n",
    "if mine_device == 'cuda':\n",
    "    mine_estimator = mine_estimator.cuda()\n",
    "mi_z_w = mine_estimator.mutual_information(z_s.to(mine_device), w_s.to(mine_device))\n",
    "print(f'I(z ; w) = {np.round(mi_z_w, 3)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b4ff9c0-d09f-40bf-ba92-f672fc71209c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter of the two z coordinates, coloured by the test labels.\n",
    "test_labels = test_set[:][1].numpy()\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_labels, cmap=cmap, s=10, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7ba2be4-334d-44f1-87b1-fa32798e17ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter of the two w coordinates, coloured by the test labels, in a square box.\n",
    "test_labels = test_set[:][1].numpy()\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_labels, cmap=cmap, s=10, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.gca().set_box_aspect(1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d645ddc3-5b8f-4fb3-8a2b-7770585c3ecb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3-D scatter of the reconstructions. They are coloured by the test labels:\n",
    "# `recons` comes from the trainer's 'test' variables, so the training labels\n",
    "# do not belong to these points.\n",
    "fig = plt.figure(figsize=(15, 15))\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "scatter = ax.scatter(recons[:, 0], recons[:, 1], recons[:, 2], c=test_set[:][1].numpy(), cmap=cmap)\n",
    "\n",
    "ax.grid(False)\n",
    "\n",
    "ax.set_xlabel('X', weight='bold', labelpad=30, fontsize=18)\n",
    "ax.set_ylabel('Y', weight='bold', labelpad=30, fontsize=18)\n",
    "ax.set_zlabel('Z', weight='bold', labelpad=7, fontsize=18)\n",
    "\n",
    "ax.set_xlim(0, 30)\n",
    "ax.set_ylim(0, 30)\n",
    "ax.set_zlim(0, 30)\n",
    "ax.axis('off')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92adea71-de0d-4aa7-9e96-f0eb073445f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_inputs = test_set[:][0]\n",
    "\n",
    "# Root-mean-squared reconstruction error over all test entries.\n",
    "recon_rmse = (test_inputs - recons).square().mean().sqrt()\n",
    "print(recon_rmse)\n",
    "\n",
    "# Mean negative log-likelihood of the test points under a per-feature Gaussian\n",
    "# built from the mean and standard deviation of the reconstructions.\n",
    "recon_gaussian = dist.Normal(recons.mean(dim=0), recons.std(dim=0))\n",
    "print(-1 * recon_gaussian.log_prob(test_inputs).sum(dim=1).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc2ebcfc-7a57-4b0a-be34-130735a56a0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate z and w column-wise so the classifier sees both latent parts.\n",
    "combined = torch.cat((z_s, w_s), dim=1)\n",
    "assert combined.shape[0] == test_set[:][0].shape[0]\n",
    "\n",
    "# Gaussian naive Bayes fitted and scored on the same test latents and labels.\n",
    "latent_targets = test_set[:][1].squeeze(-1)\n",
    "score = GaussianNB().fit(combined, latent_targets).score(combined, latent_targets)\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d5f758a-4cfe-4607-b414-30a61276e503",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian naive Bayes on z alone, fitted and scored on the same test latents and labels.\n",
    "latent_targets = test_set[:][1].squeeze(-1)\n",
    "score = GaussianNB().fit(z_s, latent_targets).score(z_s, latent_targets)\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score_common, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7fe074b-94be-4a39-a10e-e470added33f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian naive Bayes on w alone, fitted and scored on the same test latents and labels.\n",
    "latent_targets = test_set[:][1].squeeze(-1)\n",
    "score = GaussianNB().fit(w_s, latent_targets).score(w_s, latent_targets)\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score_cond, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4049a4e-bf25-4ddf-8035-176c04d00bd2",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## CSVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41648ac0-beed-4e02-9cc7-5e1b0d565773",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Reset the NumPy, PyTorch and Pyro seeds to 42 before building the model.\n",
    "for seed_fn in (np.random.seed, torch.manual_seed, pyro.util.set_rng_seed):\n",
    "    seed_fn(42)\n",
    "\n",
    "# Instantiate the model and its adversarial trainer, then run training.\n",
    "csvae = CSVAE(\n",
    "    3,\n",
    "    [1],\n",
    "    latent_dim=2,\n",
    "    w_dim=2,\n",
    "    num_layers=2,\n",
    "    recon_weight=20,\n",
    "    adversarial_weight=50,\n",
    "    z_kl_weight=2e-1,\n",
    ")\n",
    "csvae_trainer = AdversarialEpochPyroTrainer(demo_epochs, 1, 1, csvae, train_loader, test_loader)\n",
    "csvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ed50c6e-022f-451a-90c5-4d0aec249284",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the variables returned by get_variables('test') and move them to the CPU.\n",
    "preds = csvae_trainer.get_variables('test')\n",
    "z_s, w_s = (preds[key][0].cpu() for key in ('z', 'w'))\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b0a4d27-2d5e-4ffe-943d-e5127041eeac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate I(z ; w) with MINE. Run on the GPU when CUDA is available and\n",
    "# otherwise stay on the CPU instead of failing on hard-coded .cuda() calls.\n",
    "mine_device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "mine_estimator = MINE(4, [128])\n",
    "if mine_device == 'cuda':\n",
    "    mine_estimator = mine_estimator.cuda()\n",
    "mi_z_w = mine_estimator.mutual_information(z_s.to(mine_device), w_s.to(mine_device))\n",
    "print(f'I(z ; w) = {np.round(mi_z_w, 3)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25035034-759b-40b3-a9c7-9e5e6eba22d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_inputs = test_set[:][0]\n",
    "\n",
    "# Root-mean-squared reconstruction error over all test entries.\n",
    "recon_rmse = (test_inputs - recons).square().mean().sqrt()\n",
    "print(recon_rmse)\n",
    "\n",
    "# Mean negative log-likelihood of the test points under a per-feature Gaussian\n",
    "# built from the mean and standard deviation of the reconstructions.\n",
    "recon_gaussian = dist.Normal(recons.mean(dim=0), recons.std(dim=0))\n",
    "print(-1 * recon_gaussian.log_prob(test_inputs).sum(dim=1).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b36d743-a2bb-41c7-8605-ce48df186da8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter of the two z coordinates, coloured by the test labels.\n",
    "test_labels = test_set[:][1].numpy()\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_labels, cmap=cmap, s=10, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1163ce1-3154-4bc9-95c0-4fddb3129f59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter of the two w coordinates, coloured by the test labels, in a square box.\n",
    "test_labels = test_set[:][1].numpy()\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_labels, cmap=cmap, s=10, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.gca().set_box_aspect(1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1d6bdbf-8947-4042-9f51-6133194ae538",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3-D scatter of the reconstructions, coloured by the test labels.\n",
    "test_labels = test_set[:][1].numpy()\n",
    "\n",
    "fig = plt.figure(figsize=(15, 15))\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "scatter = ax.scatter(recons[:, 0], recons[:, 1], recons[:, 2], c=test_labels, cmap=cmap)\n",
    "\n",
    "ax.grid(False)\n",
    "\n",
    "label_font = dict(weight='bold', fontsize=18)\n",
    "ax.set_xlabel('X', labelpad=30, **label_font)\n",
    "ax.set_ylabel('Y', labelpad=30, **label_font)\n",
    "ax.set_zlabel('Z', labelpad=7, **label_font)\n",
    "\n",
    "for set_limits in (ax.set_xlim, ax.set_ylim):\n",
    "    set_limits(0, 30)\n",
    "\n",
    "ax.axis('off')\n",
    "\n",
    "ax.set_zlim(0, 30)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ba61c79-1b0b-477d-b2e5-7cb7d3ec1b1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate z and w column-wise so the classifier sees both latent parts.\n",
    "combined = torch.cat((z_s, w_s), dim=1)\n",
    "assert combined.shape[0] == test_set[:][0].shape[0]\n",
    "\n",
    "# Gaussian naive Bayes fitted and scored on the same test latents and labels.\n",
    "latent_targets = test_set[:][1].squeeze(-1)\n",
    "score = GaussianNB().fit(combined, latent_targets).score(combined, latent_targets)\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6eb1c74-0cc2-4801-a3ee-23620f243b89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian naive Bayes on z alone, fitted and scored on the same test latents and labels.\n",
    "latent_targets = test_set[:][1].squeeze(-1)\n",
    "score = GaussianNB().fit(z_s, latent_targets).score(z_s, latent_targets)\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score_common, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1890648-32a0-49a3-b3c1-887b27858884",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian naive Bayes on w alone, fitted and scored on the same test latents and labels.\n",
    "latent_targets = test_set[:][1].squeeze(-1)\n",
    "score = GaussianNB().fit(w_s, latent_targets).score(w_s, latent_targets)\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score_cond, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db185b21-618f-41db-9a4e-291d90501375",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## HCSVAE - No Adv."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef68b399-4e39-439b-936d-a4dc0d9110d3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Reset the NumPy, PyTorch and Pyro seeds to 42 before building the model.\n",
    "for seed_fn in (np.random.seed, torch.manual_seed, pyro.util.set_rng_seed):\n",
    "    seed_fn(42)\n",
    "\n",
    "# Instantiate the model and its trainer, then run training.\n",
    "hcsvaena = HCSVAENA(\n",
    "    3,\n",
    "    [1],\n",
    "    latent_dim=2,\n",
    "    w_dim=2,\n",
    "    num_layers=2,\n",
    "    recon_weight=20,\n",
    "    z_kl_weight=2e-1,\n",
    ")\n",
    "hcsvaena_trainer = EpochPyroTrainer(demo_epochs, hcsvaena, train_loader, test_loader)\n",
    "hcsvaena_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0712f18f-5891-4a84-ae85-dacb62963eb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the variables returned by get_variables('test') and move them to the CPU.\n",
    "preds = hcsvaena_trainer.get_variables('test')\n",
    "z_s, w_s = (preds[key][0].cpu() for key in ('z', 'w'))\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efe473ed-87c8-4979-831b-52b4a11cc8a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate I(z ; w) with MINE. Run on the GPU when CUDA is available and\n",
    "# otherwise stay on the CPU instead of failing on hard-coded .cuda() calls.\n",
    "mine_device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "mine_estimator = MINE(4, [128])\n",
    "if mine_device == 'cuda':\n",
    "    mine_estimator = mine_estimator.cuda()\n",
    "mi_z_w = mine_estimator.mutual_information(z_s.to(mine_device), w_s.to(mine_device))\n",
    "print(f'I(z ; w) = {np.round(mi_z_w, 3)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fa3b934-4ea6-45c1-b47a-ba59013a568a",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_inputs = test_set[:][0]\n",
    "\n",
    "# Root-mean-squared reconstruction error over all test entries.\n",
    "recon_rmse = (test_inputs - recons).square().mean().sqrt()\n",
    "print(recon_rmse)\n",
    "\n",
    "# Mean negative log-likelihood of the test points under a per-feature Gaussian\n",
    "# built from the mean and standard deviation of the reconstructions.\n",
    "recon_gaussian = dist.Normal(recons.mean(dim=0), recons.std(dim=0))\n",
    "print(-1 * recon_gaussian.log_prob(test_inputs).sum(dim=1).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "908e279a-f134-4581-a58a-4be5d70d69db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter of the two z coordinates, coloured by the test labels.\n",
    "test_labels = test_set[:][1].numpy()\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_labels, cmap=cmap, s=10, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00bc0002-6930-4433-85c3-5cc159c11a9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter of the two w coordinates, coloured by the test labels, in a square box.\n",
    "test_labels = test_set[:][1].numpy()\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_labels, cmap=cmap, s=10, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.gca().set_box_aspect(1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37d302fd-4c0c-432c-9a69-dc283c6c505e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3-D scatter of the reconstructions, coloured by the test labels.\n",
    "test_labels = test_set[:][1].numpy()\n",
    "\n",
    "fig = plt.figure(figsize=(15, 15))\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "scatter = ax.scatter(recons[:, 0], recons[:, 1], recons[:, 2], c=test_labels, cmap=cmap)\n",
    "\n",
    "ax.grid(False)\n",
    "\n",
    "label_font = dict(weight='bold', fontsize=18)\n",
    "ax.set_xlabel('X', labelpad=30, **label_font)\n",
    "ax.set_ylabel('Y', labelpad=30, **label_font)\n",
    "ax.set_zlabel('Z', labelpad=7, **label_font)\n",
    "\n",
    "ax.axis('off')\n",
    "\n",
    "# Same [0, 30] range on all three axes.\n",
    "for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):\n",
    "    set_limits(0, 30)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "005c247b-3974-42e4-9a36-08938adc2b12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate z and w column-wise so the classifier sees both latent parts.\n",
    "combined = torch.cat((z_s, w_s), dim=1)\n",
    "assert combined.shape[0] == test_set[:][0].shape[0]\n",
    "\n",
    "# Gaussian naive Bayes fitted and scored on the same test latents and labels.\n",
    "latent_targets = test_set[:][1].squeeze(-1)\n",
    "score = GaussianNB().fit(combined, latent_targets).score(combined, latent_targets)\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a47b9ca6-4fb0-479b-bdfe-266f0191db54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian naive Bayes on z alone, fitted and scored on the same test latents and labels.\n",
    "latent_targets = test_set[:][1].squeeze(-1)\n",
    "score = GaussianNB().fit(z_s, latent_targets).score(z_s, latent_targets)\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score_common, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba4b6bc3-77a6-4d69-8f86-291ce2855f3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian naive Bayes on w alone, fitted and scored on the same test latents and labels.\n",
    "latent_targets = test_set[:][1].squeeze(-1)\n",
    "score = GaussianNB().fit(w_s, latent_targets).score(w_s, latent_targets)\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score_cond, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1eb32a20-3d6c-4649-b58b-9eb0a188bd4d",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## DIVA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0a8e21f-0574-436e-af69-3e8ea857ef82",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Reset the NumPy, PyTorch and Pyro seeds to 42 before building the model.\n",
    "for seed_fn in (np.random.seed, torch.manual_seed, pyro.util.set_rng_seed):\n",
    "    seed_fn(42)\n",
    "\n",
    "# Instantiate the model and its trainer, then run training.\n",
    "diva = SDIVA(\n",
    "    3,\n",
    "    [1],\n",
    "    latent_dim=2,\n",
    "    w_dim=2,\n",
    "    num_layers=2,\n",
    "    recon_weight=20,\n",
    "    kl_weight=2e-1,\n",
    ")\n",
    "diva_trainer = EpochPyroTrainer(demo_epochs, diva, train_loader, test_loader)\n",
    "diva_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb974689-e1d6-4796-bbd8-99b3807cdceb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the variables returned by get_variables('test') and move them to the CPU.\n",
    "preds = diva_trainer.get_variables('test')\n",
    "z_s, w_s = (preds[key][0].cpu() for key in ('z', 'w'))\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55bcdd21-0aa9-464b-aae4-5c83636ea138",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate I(z ; w) with MINE. Run on the GPU when CUDA is available and\n",
    "# otherwise stay on the CPU instead of failing on hard-coded .cuda() calls.\n",
    "mine_device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "mine_estimator = MINE(4, [128])\n",
    "if mine_device == 'cuda':\n",
    "    mine_estimator = mine_estimator.cuda()\n",
    "mi_z_w = mine_estimator.mutual_information(z_s.to(mine_device), w_s.to(mine_device))\n",
    "print(f'I(z ; w) = {np.round(mi_z_w, 3)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f723b270-add9-4e22-8a8f-c4137bb57e60",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_inputs = test_set[:][0]\n",
    "\n",
    "# Root-mean-squared reconstruction error over all test entries.\n",
    "recon_rmse = (test_inputs - recons).square().mean().sqrt()\n",
    "print(recon_rmse)\n",
    "\n",
    "# Mean negative log-likelihood of the test points under a per-feature Gaussian\n",
    "# built from the mean and standard deviation of the reconstructions.\n",
    "recon_gaussian = dist.Normal(recons.mean(dim=0), recons.std(dim=0))\n",
    "print(-1 * recon_gaussian.log_prob(test_inputs).sum(dim=1).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8031a984-8d46-4af8-b3b2-7657b98fe3e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter of the two z coordinates, coloured by the test labels.\n",
    "test_labels = test_set[:][1].numpy()\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_labels, cmap=cmap, s=10, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ada9e9c-b962-455d-b815-93552c3e193c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter of the two w coordinates, coloured by the test labels, in a square box.\n",
    "test_labels = test_set[:][1].numpy()\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_labels, cmap=cmap, s=10, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.gca().set_box_aspect(1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af4aabe7-f01c-4fc3-97e0-ffe6d9454a3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3-D scatter of the reconstructions, coloured by the test labels.\n",
    "test_labels = test_set[:][1].numpy()\n",
    "\n",
    "fig = plt.figure(figsize=(15, 15))\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "scatter = ax.scatter(recons[:, 0], recons[:, 1], recons[:, 2], c=test_labels, cmap=cmap)\n",
    "\n",
    "label_font = dict(weight='bold', fontsize=18)\n",
    "ax.set_xlabel('X', labelpad=30, **label_font)\n",
    "ax.set_ylabel('Y', labelpad=30, **label_font)\n",
    "ax.set_zlabel('Z', labelpad=7, **label_font)\n",
    "\n",
    "# Same [0, 30] range on all three axes.\n",
    "for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):\n",
    "    set_limits(0, 30)\n",
    "\n",
    "ax.axis('off')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8036dabb-3f36-4937-b649-6165a7ba7a6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate z and w column-wise so the classifier sees both latent parts.\n",
    "combined = torch.cat((z_s, w_s), dim=1)\n",
    "assert combined.shape[0] == test_set[:][0].shape[0]\n",
    "\n",
    "# Gaussian naive Bayes fitted and scored on the same test latents and labels.\n",
    "latent_targets = test_set[:][1].squeeze(-1)\n",
    "score = GaussianNB().fit(combined, latent_targets).score(combined, latent_targets)\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc6e8e66-0b1d-400e-bd69-91861cb60fe5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian naive Bayes on z alone, fitted and scored on the same test latents and labels.\n",
    "latent_targets = test_set[:][1].squeeze(-1)\n",
    "score = GaussianNB().fit(z_s, latent_targets).score(z_s, latent_targets)\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score_common, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "164cca18-cd79-4cd4-a87b-e84efe18e238",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian naive Bayes on w alone, fitted and scored on the same test latents and labels.\n",
    "latent_targets = test_set[:][1].squeeze(-1)\n",
    "score = GaussianNB().fit(w_s, latent_targets).score(w_s, latent_targets)\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score_cond, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48e60a38-0946-4b7c-a767-2a7584562bf8",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## CCVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e95161e5-cecc-4066-9ee9-4ecede0a6469",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Reset the NumPy, PyTorch and Pyro seeds to 42 before building the model.\n",
    "for seed_fn in (np.random.seed, torch.manual_seed, pyro.util.set_rng_seed):\n",
    "    seed_fn(42)\n",
    "\n",
    "# Instantiate the model and its trainer, then run training.\n",
    "ccvae = CCVAE(\n",
    "    3,\n",
    "    [1],\n",
    "    latent_dim=2,\n",
    "    w_dim=2,\n",
    "    num_layers=2,\n",
    "    recon_weight=20,\n",
    "    kl_weight=2e-1,\n",
    ")\n",
    "ccvae_trainer = EpochPyroTrainer(demo_epochs, ccvae, train_loader, test_loader)\n",
    "ccvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "190bd5dc-34f0-40e6-bbf7-54a1c3bb576f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the variables returned by get_variables('test') and move them to the CPU.\n",
    "preds = ccvae_trainer.get_variables('test')\n",
    "z_s, w_s = (preds[key][0].cpu() for key in ('z', 'w'))\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d52d42f5-56df-4535-bcb3-729aa542844a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate I(z ; w) with MINE. Run on the GPU when CUDA is available and\n",
    "# otherwise stay on the CPU instead of failing on hard-coded .cuda() calls.\n",
    "mine_device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "mine_estimator = MINE(4, [128])\n",
    "if mine_device == 'cuda':\n",
    "    mine_estimator = mine_estimator.cuda()\n",
    "mi_z_w = mine_estimator.mutual_information(z_s.to(mine_device), w_s.to(mine_device))\n",
    "print(f'I(z ; w) = {np.round(mi_z_w, 3)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81b82cd8-9b4b-47e2-8ea1-bf12487054ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_inputs = test_set[:][0]\n",
    "\n",
    "# Root-mean-squared reconstruction error over all test entries.\n",
    "recon_rmse = (test_inputs - recons).square().mean().sqrt()\n",
    "print(recon_rmse)\n",
    "\n",
    "# Mean negative log-likelihood of the test points under a per-feature Gaussian\n",
    "# built from the mean and standard deviation of the reconstructions.\n",
    "recon_gaussian = dist.Normal(recons.mean(dim=0), recons.std(dim=0))\n",
    "print(-1 * recon_gaussian.log_prob(test_inputs).sum(dim=1).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df7728f0-59bd-4dd4-bccc-b3e1bf228d1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter of the two z coordinates, coloured by the test labels.\n",
    "test_labels = test_set[:][1].numpy()\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_labels, cmap=cmap, s=10, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3813d1c-50b8-494c-9abb-3ae590fc85cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter of the two w coordinates, coloured by the test labels, in a square box.\n",
    "test_labels = test_set[:][1].numpy()\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_labels, cmap=cmap, s=10, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.gca().set_box_aspect(1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dbbf24e-485e-4ed4-8edd-4ff0c47996c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3-D scatter of the reconstructions, coloured by the test labels.\n",
    "test_labels = test_set[:][1].numpy()\n",
    "\n",
    "fig = plt.figure(figsize=(15, 15))\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "scatter = ax.scatter(recons[:, 0], recons[:, 1], recons[:, 2], c=test_labels, cmap=cmap)\n",
    "\n",
    "ax.grid(False)\n",
    "\n",
    "label_font = dict(weight='bold', fontsize=18)\n",
    "ax.set_xlabel('X', labelpad=30, **label_font)\n",
    "ax.set_ylabel('Y', labelpad=30, **label_font)\n",
    "ax.set_zlabel('Z', labelpad=7, **label_font)\n",
    "\n",
    "# Same [0, 30] range on all three axes.\n",
    "for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):\n",
    "    set_limits(0, 30)\n",
    "\n",
    "ax.axis('off')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44c90e9a-0cc7-4d4e-8380-5ae75f9f9a72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate z and w column-wise so the classifier sees both latent parts.\n",
    "combined = torch.cat((z_s, w_s), dim=1)\n",
    "assert combined.shape[0] == test_set[:][0].shape[0]\n",
    "\n",
    "# Gaussian naive Bayes fitted and scored on the same test latents and labels.\n",
    "latent_targets = test_set[:][1].squeeze(-1)\n",
    "score = GaussianNB().fit(combined, latent_targets).score(combined, latent_targets)\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "415786d0-c0f5-440c-83cd-9165f05fbb37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian naive Bayes on z alone, fitted and scored on the same test latents and labels.\n",
    "latent_targets = test_set[:][1].squeeze(-1)\n",
    "score = GaussianNB().fit(z_s, latent_targets).score(z_s, latent_targets)\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score_common, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31504a2e-bcd4-4afe-83e8-65617cb64aeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian naive Bayes on w alone, fitted and scored on the same test latents and labels.\n",
    "latent_targets = test_set[:][1].squeeze(-1)\n",
    "score = GaussianNB().fit(w_s, latent_targets).score(w_s, latent_targets)\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score_cond, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73d81c07-b61d-4f3c-9a76-b3d0b0e521c5",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## HCSVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8df96f19-6652-4291-ac14-4dbe07b5c447",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Reset the NumPy, PyTorch and Pyro seeds to 42 before building the model.\n",
    "for seed_fn in (np.random.seed, torch.manual_seed, pyro.util.set_rng_seed):\n",
    "    seed_fn(42)\n",
    "\n",
    "# Instantiate the model and its adversarial trainer, then run training.\n",
    "hcsvae = HCSVAE(\n",
    "    3,\n",
    "    [1],\n",
    "    latent_dim=2,\n",
    "    w_dim=2,\n",
    "    num_layers=2,\n",
    "    recon_weight=20,\n",
    "    adversarial_weight=50,\n",
    "    z_kl_weight=5e-1,\n",
    ")\n",
    "hcsvae_trainer = AdversarialEpochPyroTrainer(demo_epochs, 1, 1, hcsvae, train_loader, test_loader)\n",
    "hcsvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b01530d3-612a-4275-9b2f-db717f7ce9fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the variables returned by get_variables('test') and move them to the CPU.\n",
    "preds = hcsvae_trainer.get_variables('test')\n",
    "z_s, w_s = (preds[key][0].cpu() for key in ('z', 'w'))\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acb8edd4-9f75-4093-8a79-4811d1137e38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate I(z ; w) with MINE. Run on the GPU when CUDA is available and\n",
    "# otherwise stay on the CPU instead of failing on hard-coded .cuda() calls.\n",
    "mine_device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "mine_estimator = MINE(4, [128])\n",
    "if mine_device == 'cuda':\n",
    "    mine_estimator = mine_estimator.cuda()\n",
    "mi_z_w = mine_estimator.mutual_information(z_s.to(mine_device), w_s.to(mine_device))\n",
    "print(f'I(z ; w) = {np.round(mi_z_w, 3)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82a1307c-0e58-4866-b399-eee8f9a3c1dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root-mean-square difference between the first element of test_set and recons.\n",
    "test_inputs = test_set[:][0]\n",
    "rmse = torch.sqrt(torch.mean(torch.square(test_inputs - recons)))\n",
    "print(rmse)\n",
    "\n",
    "# Negative log-probability of the same inputs under a Normal whose mean and std\n",
    "# are taken over dim 0 of recons; summed over dim 1 and then averaged.\n",
    "recon_dist = dist.Normal(recons.mean(dim=0), recons.std(dim=0))\n",
    "print(-1 * recon_dist.log_prob(test_inputs).sum(dim=1).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5f51c79-2fc8-4064-81c4-9ad8868114c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D scatter of the two z_s columns, coloured by the second element of test_set.\n",
    "label_colors = test_set[:][1].numpy()\n",
    "scatter_style = dict(cmap=cmap, s=10, alpha=0.6)\n",
    "\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=label_colors, **scatter_style)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90b10486-72ca-45c7-b17f-2d28ad7fc83f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D scatter of the two w_s columns, coloured by the second element of test_set.\n",
    "label_colors = test_set[:][1].numpy()\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=label_colors, cmap=cmap, s=10, alpha=0.6)\n",
    "\n",
    "# Hide the axes and force a square plotting box on the current axes.\n",
    "current_axes = plt.gca()\n",
    "current_axes.axis('off')\n",
    "current_axes.set_box_aspect(1)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab730432-fe4a-427c-9e91-9dae5e418593",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3-D scatter of the first three columns of recons, coloured by the second\n",
    "# element of test_set.\n",
    "fig = plt.figure(figsize=(15, 15))\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "\n",
    "x_vals, y_vals, z_vals = recons[:, :3].unbind(dim=1)\n",
    "scatter = ax.scatter(x_vals, y_vals, z_vals, c=test_set[:][1].numpy(), cmap=cmap)\n",
    "\n",
    "ax.grid(False)\n",
    "\n",
    "# Axis labels share one text style; only the padding differs for Z.\n",
    "label_style = dict(weight='bold', fontsize=18)\n",
    "ax.set_xlabel('X', labelpad=30, **label_style)\n",
    "ax.set_ylabel('Y', labelpad=30, **label_style)\n",
    "ax.set_zlabel('Z', labelpad=7, **label_style)\n",
    "\n",
    "# Same [0, 30] range on all three axes.\n",
    "for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):\n",
    "    set_limits(0, 30)\n",
    "\n",
    "# The axes are switched off last, after the labels and limits are configured.\n",
    "ax.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "859250fc-311b-436d-912c-7bb81ba56787",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Join z_s and w_s along dim 1 and check there is one row per test sample.\n",
    "combined = torch.cat((z_s, w_s), dim=1)\n",
    "assert combined.shape[0] == test_set[:][0].shape[0]\n",
    "\n",
    "# Gaussian naive Bayes fitted and scored on the same joined latents, compared\n",
    "# with orig_data_score (defined outside this cell).\n",
    "test_labels = test_set[:][1].squeeze(-1)\n",
    "joint_classifier = GaussianNB().fit(combined, test_labels)\n",
    "score = joint_classifier.score(combined, test_labels)\n",
    "\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c553b48-c17e-45d4-bc3f-e80c5be0103e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian naive Bayes fitted and scored on z_s alone, compared with\n",
    "# orig_data_score_common (defined outside this cell).\n",
    "test_labels = test_set[:][1].squeeze(-1)\n",
    "z_classifier = GaussianNB().fit(z_s, test_labels)\n",
    "score = z_classifier.score(z_s, test_labels)\n",
    "\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score_common, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94ee41e7-8500-44af-a36b-a476b6d7c5e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian naive Bayes fitted and scored on w_s alone, compared with\n",
    "# orig_data_score_cond (defined outside this cell).\n",
    "test_labels = test_set[:][1].squeeze(-1)\n",
    "w_classifier = GaussianNB().fit(w_s, test_labels)\n",
    "score = w_classifier.score(w_s, test_labels)\n",
    "\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score_cond, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6db86717-7848-4a2b-bbd0-67d0a40bcf76",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## DISCoVeR (Ours)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5edcfef0-6b3f-4597-9163-5e5674051373",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Re-seed NumPy, PyTorch and Pyro (in that order) right before building the model.\n",
    "for seed_fn in (np.random.seed, torch.manual_seed, pyro.util.set_rng_seed):\n",
    "    seed_fn(42)\n",
    "\n",
    "# DLVAE is the model class used under this notebook's 'DISCoVeR (Ours)' heading.\n",
    "# The positional arguments 3 and [1] are passed through unchanged (presumably\n",
    "# input and label dimensions -- confirm in VAE_variants).\n",
    "dlvae_hparams = dict(\n",
    "    latent_dim=2,\n",
    "    w_dim=2,\n",
    "    num_layers=2,\n",
    "    recon_weight=9e-1,\n",
    "    recon_weight_z=1e-1,\n",
    "    z_kl_weight=2e-1,\n",
    "    w_kl_weight=2e-1,\n",
    "    adversarial_weight=8,\n",
    ")\n",
    "dlvae = DLVAE(3, [1], **dlvae_hparams)\n",
    "\n",
    "# NOTE(review): demo_epochs is presumably the epoch count and the two literal 1\n",
    "# arguments are trainer settings defined in VAE_trainers -- confirm their meaning.\n",
    "# NOTE(review): if the trainer does not reset Pyro's global parameter store, a\n",
    "# pyro.clear_param_store() call may be needed before training -- confirm.\n",
    "dlvae_trainer = AdversarialEpochPyroTrainer(\n",
    "    demo_epochs, 1, 1, dlvae, train_loader, test_loader\n",
    ")\n",
    "dlvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45c3eaa1-c569-4ee8-918b-852d3d219d37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ask the trainer for its variables with the 'test' argument, then keep index 0\n",
    "# of the z / w entries and index [0, 0] of both reconstruction entries, on CPU.\n",
    "# NOTE(review): the leading axes are presumably sample dimensions -- confirm.\n",
    "preds = dlvae_trainer.get_variables('test')\n",
    "z_s, w_s = (preds[name][0].cpu() for name in ('z', 'w'))\n",
    "recons_z, recons_w = (preds[name][0, 0].cpu() for name in ('rec_z', 'rec_w'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cdca18e-c84f-4df7-b254-8e2507f81f42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate the mutual information between z_s and w_s with a freshly built MINE\n",
    "# estimator. The estimator and both tensors are moved with .cuda(), so this cell\n",
    "# needs a CUDA device.\n",
    "# NOTE(review): steps=100 is passed here but not in the HCSVAE MINE cell -- confirm\n",
    "# the two estimates are meant to use different settings.\n",
    "mine_estimator = MINE(4, [128]).cuda()\n",
    "mi_z_w = mine_estimator.mutual_information(z_s.cuda(), w_s.cuda(), steps=100)\n",
    "print(f'I(z ; w) = {np.round(mi_z_w, 3)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31352f6f-9ccc-4890-bb1a-035731946521",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root-mean-square difference between the first element of test_set and recons_w.\n",
    "test_inputs = test_set[:][0]\n",
    "rmse = torch.sqrt(torch.mean(torch.square(test_inputs - recons_w)))\n",
    "print(rmse)\n",
    "\n",
    "# Negative log-probability of the same inputs under a Normal whose mean and std\n",
    "# are taken over dim 0 of recons_w; summed over dim 1 and then averaged.\n",
    "recon_dist = dist.Normal(recons_w.mean(dim=0), recons_w.std(dim=0))\n",
    "print(-1 * recon_dist.log_prob(test_inputs).sum(dim=1).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89688f03-02b3-49ee-b413-e62e9512c3d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D scatter of the two z_s columns, coloured by the second element of test_set.\n",
    "label_colors = test_set[:][1].numpy()\n",
    "scatter_style = dict(cmap=cmap, s=10, alpha=0.6)\n",
    "\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=label_colors, **scatter_style)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5956b42-2e5b-4750-ae50-31b4d32e6cc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D scatter of the two w_s columns, coloured by the second element of test_set.\n",
    "# NOTE(review): unlike the z scatter, the colormap is reversed here -- confirm\n",
    "# this is intentional.\n",
    "label_colors = test_set[:][1].numpy()\n",
    "flipped_cmap = cmap.reversed()\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=label_colors, cmap=flipped_cmap, s=10, alpha=0.6)\n",
    "\n",
    "# Hide the axes and force a square plotting box on the current axes.\n",
    "current_axes = plt.gca()\n",
    "current_axes.axis('off')\n",
    "current_axes.set_box_aspect(1)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35331d34-7497-4c83-8c7e-d37b192c642b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3-D scatter of the first three columns of recons_z, coloured by the second\n",
    "# element of test_set.\n",
    "fig = plt.figure(figsize=(15, 15))\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "\n",
    "x_vals, y_vals, z_vals = recons_z[:, :3].unbind(dim=1)\n",
    "scatter = ax.scatter(x_vals, y_vals, z_vals, c=test_set[:][1].numpy(), cmap=cmap)\n",
    "\n",
    "# The axes are hidden before the labels and limits are configured.\n",
    "ax.axis('off')\n",
    "\n",
    "# Axis labels share one text style; only the padding differs for Z.\n",
    "label_style = dict(weight='bold', fontsize=18)\n",
    "ax.set_xlabel('X', labelpad=30, **label_style)\n",
    "ax.set_ylabel('Y', labelpad=30, **label_style)\n",
    "ax.set_zlabel('Z', labelpad=7, **label_style)\n",
    "\n",
    "# Same [0, 30] range on all three axes.\n",
    "for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):\n",
    "    set_limits(0, 30)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcc7e6bf-aca2-4db4-a8e0-92d835be0f18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3-D scatter of the first three columns of recons_w, coloured by the second\n",
    "# element of test_set.\n",
    "fig = plt.figure(figsize=(15,15))\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "scatter = ax.scatter(recons_w[:,0], recons_w[:,1], recons_w[:,2], c=test_set[:][1].numpy(), cmap=cmap)\n",
    "\n",
    "# Hide the grid and the axes. The axes are switched off once here; a second,\n",
    "# redundant ax.axis('off') call further down was removed.\n",
    "ax.grid(False)\n",
    "ax.axis('off')\n",
    "\n",
    "ax.set_xlabel('X', weight='bold', labelpad=30, fontsize=18)\n",
    "ax.set_ylabel('Y', weight='bold', labelpad=30, fontsize=18)\n",
    "ax.set_zlabel('Z', weight='bold', labelpad=7, fontsize=18)\n",
    "\n",
    "# Same [0, 30] range on all three axes.\n",
    "ax.set_xlim(0,30)\n",
    "ax.set_ylim(0,30)\n",
    "ax.set_zlim(0,30)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a516ef07-1c67-4c3c-b6c4-e57e6470aa03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Join z_s and w_s along dim 1 and check there is one row per test sample.\n",
    "combined = torch.cat((z_s, w_s), dim=1)\n",
    "assert combined.shape[0] == test_set[:][0].shape[0]\n",
    "\n",
    "# Gaussian naive Bayes fitted and scored on the same joined latents, compared\n",
    "# with orig_data_score (defined outside this cell).\n",
    "test_labels = test_set[:][1].squeeze(-1)\n",
    "joint_classifier = GaussianNB().fit(combined, test_labels)\n",
    "score = joint_classifier.score(combined, test_labels)\n",
    "\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f403a4a6-d83a-4198-8ca1-74c880a2bc5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian naive Bayes fitted and scored on z_s alone, compared with\n",
    "# orig_data_score_common (defined outside this cell).\n",
    "test_labels = test_set[:][1].squeeze(-1)\n",
    "z_classifier = GaussianNB().fit(z_s, test_labels)\n",
    "score = z_classifier.score(z_s, test_labels)\n",
    "\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score_common, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "968d9acd-e5fc-493b-9e95-8c541fd40baf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian naive Bayes fitted and scored on w_s alone, compared with\n",
    "# orig_data_score_cond (defined outside this cell).\n",
    "test_labels = test_set[:][1].squeeze(-1)\n",
    "w_classifier = GaussianNB().fit(w_s, test_labels)\n",
    "score = w_classifier.score(w_s, test_labels)\n",
    "\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score_cond, 2) - np.round(score, 4)), 4)}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
