{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd194bc5-0032-4c7a-957f-f9a622d90391",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os, pickle\n",
    "sys.path.append('./src/')\n",
    "\n",
    "from VAE_variants import VAE, CVAE, CSVAENA, CSVAE, HCSVAENA, HCSVAE, DLVAE, SDIVA, CCVAE\n",
    "from VAE_trainers import EpochPyroTrainer, AdversarialEpochPyroTrainer, ThresholdPyroTrainer, AdversarialThresholdPyroTrainer\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "from sklearn.naive_bayes import GaussianNB\n",
    "from tqdm import tqdm, trange\n",
    "import pyro.distributions as dist\n",
    "\n",
    "import torch, pyro\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pyro.optim as opt\n",
    "import seaborn as sns\n",
    "\n",
    "demo_epochs=5\n",
    "\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "cmap = LinearSegmentedColormap.from_list(\"cmap\", [\"#F23E2E\", \"#5888A6\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ef2d5df-f7f4-4b5d-8a75-57fd9fd8cfa1",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81c341ac-fb66-4999-ba38-263ea0fee4d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.utils.data as utils\n",
    "\n",
    "#################### DATASET PARAMS #########################################################################################\n",
    "n_samples = 30000\n",
    "batch_size = 64\n",
    "##################################################################################################################################\n",
    "\n",
    "# Ground-truth latents: two independent standard-normal draws per sample.\n",
    "# z is drawn before w so the NumPy random stream is consumed in a fixed order.\n",
    "z_true = np.random.randn(n_samples)\n",
    "w_true = np.random.randn(n_samples)\n",
    "\n",
    "# Observation x = z + w (so x ~ N(0, 2)), stored as a float32 column of shape [N, 1].\n",
    "x = torch.tensor(z_true + w_true, dtype=torch.float32).unsqueeze(1)\n",
    "\n",
    "# Binary label y = 1 if w > 0 else 0, stored as a float column of shape [N, 1].\n",
    "y = torch.tensor((w_true > 0).astype(np.int64)).float().reshape(-1, 1)\n",
    "\n",
    "# Each dataset item is (x, y, latents); the third tensor holds z_true and w_true\n",
    "# side by side, shape [N, 2].\n",
    "dataset = utils.TensorDataset(\n",
    "    x,\n",
    "    y,\n",
    "    torch.hstack((\n",
    "        torch.FloatTensor(z_true.reshape(-1, 1)),\n",
    "        torch.FloatTensor(w_true.reshape(-1, 1)),\n",
    "    )),\n",
    ")\n",
    "\n",
    "# 50/50 random split, then turn each Subset back into a plain TensorDataset\n",
    "# so that train_set[:] / test_set[:] return full tensors.\n",
    "train_set, test_set = utils.random_split(dataset, [0.5, 0.5])\n",
    "train_set = utils.TensorDataset(*train_set[:])\n",
    "test_set = utils.TensorDataset(*test_set[:])\n",
    "\n",
    "# Only the training loader shuffles; the test loader keeps dataset order.\n",
    "train_loader = utils.DataLoader(train_set, shuffle=True, batch_size=batch_size)\n",
    "test_loader = utils.DataLoader(test_set, shuffle=False, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9ad4d1c-efdf-4e95-ad09-a183b0b88a91",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(w_true, color='#9E666F', alpha=0.6, bins=75)\n",
    "plt.xlim(-6,6)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41b89477-d2f7-4a6a-abc6-59f5d5f6ac71",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(z_true, color='#9E666F', alpha=0.6, bins=75)\n",
    "plt.xlim(-6,6)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdd1cd96-bfa8-47cb-9710-d1b8ace04ea7",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(x.flatten().numpy(), color='#9E666F', alpha=0.6, bins=75)\n",
    "plt.xlim(-6,6)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86b91e2e-32a0-4b1b-a1b3-54b6ddd45218",
   "metadata": {},
   "outputs": [],
   "source": [
    "orig_data_score = GaussianNB().fit(train_set[:][0], train_set[:][1].squeeze(-1)).score(train_set[:][0], train_set[:][1].squeeze(-1))\n",
    "print(f'Bayes classifier acc - Combined: {np.round(orig_data_score, 2)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfd46140-0da1-4371-9932-171f60db9cf7",
   "metadata": {},
   "source": [
    "## CSVAE - No Adv."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfdc196a-0365-4396-b716-665309ef6d17",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "csvaena = CSVAENA(1, [1], latent_dim=1, w_dim=1, hidden_dim=8)\n",
    "csvaena_trainer = EpochPyroTrainer(demo_epochs, csvaena, train_loader, test_loader)\n",
    "csvaena_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "185292b0-5133-4428-946e-164bb405bd5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "csvaena_trainer._predictive_setup(s=1)\n",
    "preds = csvaena_trainer.get_variables('test')\n",
    "z_s = preds['z'][0].cpu()\n",
    "w_s = preds['w'][0].cpu()\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c3f6478-c3c2-437b-94e1-72a9169f629f",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(dist.Normal(recons.mean(), recons.std()).log_prob(test_set[:][0]).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74ca9744-5094-4866-b523-fb32368be6da",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.stripplot(recons.flatten().numpy(), orient='y', c=test_set[:][1].numpy(), cmap=cmap, size=5)\n",
    "plt.xlim(-6,6)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83c7293e-defc-44d2-afe3-90ef1f8e6f43",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.stripplot(z_s, orient='y', c=test_set[:][1].numpy(), cmap=cmap, alpha=0.6)\n",
    "plt.xlim(-6,6)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d76cdad-cce5-4fa2-931f-0770ff5d03da",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.stripplot(w_s, orient='y', c=test_set[:][1].numpy(), cmap=cmap, alpha=0.6)\n",
    "plt.xlim(-6,6)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e548f215-a6b8-4169-a403-9854cb6a558e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probe how much label information the latents carry: fit and score a Gaussian\n",
    "# naive Bayes classifier on the concatenated (z, w) test-set latents.\n",
    "combined = torch.concatenate((z_s, w_s), dim=1)\n",
    "\n",
    "# The latents must line up one-to-one with the test split they are scored against.\n",
    "assert combined.shape[0] == test_set[:][0].shape[0]\n",
    "score = GaussianNB().fit(combined, test_set[:][1].squeeze(-1)).score(combined, test_set[:][1].squeeze(-1))\n",
    "print(f'Bayes classifier acc: {np.round(score, 2)}')\n",
    "# Absolute gap to the accuracy obtained on the raw observations (orig_data_score).\n",
    "print(f'Diff: {np.round(abs(orig_data_score - score), 2)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ed961c6-5f74-4170-b33c-a0b75bfaaebb",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(w_s.flatten().numpy(), color='#9E666F', alpha=0.6, bins=500)\n",
    "plt.xlim(-1,6)\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00573ebd-27c0-487b-b9ef-77024fbe9e24",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(z_s.flatten().numpy(), color='#9E666F', alpha=0.6, bins=500)\n",
    "plt.xlim(-6,6)\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b4ff9c0-d09f-40bf-ba92-f672fc71209c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot X Samples\n",
    "plt.hist(recons.flatten().numpy(), color='purple', alpha=0.6, bins=500)\n",
    "plt.title(\"Reconstructed X Samples\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4049a4e-bf25-4ddf-8035-176c04d00bd2",
   "metadata": {},
   "source": [
    "## CSVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fbf55bd-b365-4be5-9d90-d6f64f44c8f2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "csvae = CSVAE(1, [1], latent_dim=1, w_dim=1, hidden_dim=8, recon_weight=2.5, adversarial_weight=20, w_kl_weight=5e-1)\n",
    "csvae_trainer = AdversarialEpochPyroTrainer(demo_epochs, 1, 1, csvae, train_loader, test_loader)\n",
    "csvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ed50c6e-022f-451a-90c5-4d0aec249284",
   "metadata": {},
   "outputs": [],
   "source": [
    "csvae_trainer._predictive_setup(s=1)\n",
    "preds = csvae_trainer.get_variables('test')\n",
    "z_s = preds['z'][0].cpu()\n",
    "w_s = preds['w'][0].cpu()\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84ff78ab-dc60-4760-ae5d-a59a8da75dd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(dist.Normal(recons.mean(), recons.std()).log_prob(test_set[:][0]).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d97597f4-82bd-4a98-8122-f2f2639356ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot X Samples\n",
    "sns.stripplot(recons.flatten().numpy(), orient='y', c=test_set[:][1].numpy(), cmap=cmap)\n",
    "plt.xlim(-6,6)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4ba41e4-a9b7-411f-aa2d-7ea0a62c024d",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.stripplot(z_s, orient='y', c=test_set[:][1].numpy(), cmap=cmap)\n",
    "plt.xlim(-6,6)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41b4f3f1-c746-4604-a5b7-9f3e4c5f2894",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.stripplot(w_s, orient='y', c=test_set[:][1].numpy(), cmap=cmap)\n",
    "plt.xlim(-6,6)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ded805c-5c02-4d75-ba70-68c4122286a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probe how much label information the latents carry: fit and score a Gaussian\n",
    "# naive Bayes classifier on the concatenated (z, w) test-set latents.\n",
    "combined = torch.concatenate((z_s, w_s), dim=1)\n",
    "\n",
    "# The latents come from the test split, so check the row count against test_set\n",
    "# (comparing with train_set only passes by accident because the split is 50/50).\n",
    "assert combined.shape[0] == test_set[:][0].shape[0]\n",
    "score = GaussianNB().fit(combined, test_set[:][1].squeeze(-1)).score(combined, test_set[:][1].squeeze(-1))\n",
    "print(f'Bayes classifier acc: {np.round(score, 2)}')\n",
    "# Absolute gap to the accuracy obtained on the raw observations (orig_data_score).\n",
    "print(f'Diff: {np.round(abs(orig_data_score - score), 2)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60a96e3b-c4a3-40d4-b5df-62bb7540a79d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(z_s.flatten().numpy(), color='#9E666F', alpha=0.6, bins=500)\n",
    "plt.xlim(-6,6)\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa65ce28-7443-41cc-b0ef-954e9f41f7aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(w_s.flatten().numpy(), color='#9E666F', alpha=0.6, bins=500)\n",
    "plt.xlim(-1,6)\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2e3b41d-1841-4158-911e-151c4086f20b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot X Samples\n",
    "plt.hist(recons.flatten().numpy(), color='purple', alpha=0.6, bins=500)\n",
    "plt.title(\"Reconstructed X Samples\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db185b21-618f-41db-9a4e-291d90501375",
   "metadata": {},
   "source": [
    "## HCSVAE - No Adv."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef68b399-4e39-439b-936d-a4dc0d9110d3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "hcsvaena = HCSVAENA(1, [1], latent_dim=1, w_dim=1, hidden_dim=8, w_kl_weight=5e-1)\n",
    "hcsvaena_trainer = EpochPyroTrainer(demo_epochs, hcsvaena, train_loader, test_loader)\n",
    "hcsvaena_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0712f18f-5891-4a84-ae85-dacb62963eb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "hcsvaena_trainer._predictive_setup(s=1)\n",
    "preds = hcsvaena_trainer.get_variables('test')\n",
    "z_s = preds['z'][0].cpu()\n",
    "w_s = preds['w'][0].cpu()\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54d24682-6258-4c36-b598-a1f67c5087a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(dist.Normal(recons.mean(), recons.std()).log_prob(test_set[:][0]).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9ef24cc-a2c0-4c6d-af92-72bb03fbeed0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot X Samples\n",
    "sns.stripplot(recons.flatten().numpy(), orient='y', c=test_set[:][1].numpy(), cmap=cmap, size=5)\n",
    "plt.xlim(-6,6)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f74ec4dd-ea82-4f57-b9ef-a8d56455ec3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.stripplot(z_s, orient='y', c=test_set[:][1].numpy(), cmap=cmap, alpha=0.6)\n",
    "plt.xlim(-6,6)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90281e78-c676-4d71-8c0d-bff43947b652",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.stripplot(w_s, orient='y', c=test_set[:][1].numpy(), cmap=cmap, alpha=0.6)\n",
    "plt.xlim(-1,6)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a495d6ab-8cd5-4589-aef3-4ef8ab55eaa3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probe how much label information the latents carry: fit and score a Gaussian\n",
    "# naive Bayes classifier on the concatenated (z, w) test-set latents.\n",
    "combined = torch.concatenate((z_s, w_s), dim=1)\n",
    "\n",
    "# The latents must line up one-to-one with the test split they are scored against.\n",
    "assert combined.shape[0] == test_set[:][0].shape[0]\n",
    "score = GaussianNB().fit(combined, test_set[:][1].squeeze(-1)).score(combined, test_set[:][1].squeeze(-1))\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "# NOTE: the baseline is rounded to 2 decimals before differencing, so the 4-decimal\n",
    "# gap printed here is relative to that rounded baseline.\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e418ee7-bad4-4cd6-bc43-1a1bac4c6cdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(w_s.flatten().numpy(), color='#9E666F', alpha=0.6, bins=500)\n",
    "plt.xlim(-1,6)\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1cd55bb-fb18-4aa5-9898-0f8150a2fa5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(z_s.flatten().numpy(), color='#9E666F', alpha=0.6, bins=500)\n",
    "plt.xlim(-6,6)\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "908e279a-f134-4581-a58a-4be5d70d69db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot X Samples\n",
    "plt.hist(recons.flatten().numpy(), color='purple', alpha=0.6, bins=500)\n",
    "plt.title(\"Reconstructed X Samples\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73d81c07-b61d-4f3c-9a76-b3d0b0e521c5",
   "metadata": {},
   "source": [
    "## HCSVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8df96f19-6652-4291-ac14-4dbe07b5c447",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "hcsvae = HCSVAE(1, [1], latent_dim=1, w_dim=1, hidden_dim=8, recon_weight=2.5, adversarial_weight=20, w_kl_weight=5e-1)\n",
    "hcsvae_trainer = AdversarialEpochPyroTrainer(demo_epochs, 1, 1, hcsvae, train_loader, test_loader)\n",
    "hcsvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b01530d3-612a-4275-9b2f-db717f7ce9fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "hcsvae_trainer._predictive_setup(s=1)\n",
    "preds = hcsvae_trainer.get_variables('test')\n",
    "z_s = preds['z'][0].cpu()\n",
    "w_s = preds['w'][0].cpu()\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc44f3d5-ba4f-437e-9354-853066ed2d65",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(dist.Normal(recons.mean(), recons.std()).log_prob(test_set[:][0]).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6449a0a-6465-47e4-bb9c-0b5b92723807",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot X Samples\n",
    "sns.stripplot(recons.flatten().numpy(), orient='y', c=test_set[:][1].numpy(), cmap=cmap, size=5)\n",
    "plt.axis('off')\n",
    "plt.xlim(-6,6)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bdf7cd0-6c1d-46dc-860b-576fc2f4c1be",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.stripplot(z_s, orient='y', c=test_set[:][1].numpy(), cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.xlim(-6,6)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67859f92-0c35-4600-b088-4d6af77c7252",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.stripplot(w_s, orient='y', c=test_set[:][1].numpy(), cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.xlim(-6,6)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5680fa29-adc4-470a-8252-dd8d0d399f5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probe how much label information the latents carry: fit and score a Gaussian\n",
    "# naive Bayes classifier on the concatenated (z, w) test-set latents.\n",
    "combined = torch.concatenate((z_s, w_s), dim=1)\n",
    "\n",
    "# The latents must line up one-to-one with the test split they are scored against.\n",
    "assert combined.shape[0] == test_set[:][0].shape[0]\n",
    "score = GaussianNB().fit(combined, test_set[:][1].squeeze(-1)).score(combined, test_set[:][1].squeeze(-1))\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "# NOTE: the baseline is rounded to 2 decimals before differencing, so the 4-decimal\n",
    "# gap printed here is relative to that rounded baseline.\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d029327-8e39-4b71-af2d-a51e18a37e50",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(z_s.flatten().numpy(), color='#9E666F', alpha=0.6, bins=500)\n",
    "plt.xlim(-6,6)\n",
    "plt.axis('off')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8470068-adce-47bb-b13d-2ac977df3cc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(w_s.flatten().numpy(), color='#9E666F', alpha=0.6, bins=500)\n",
    "plt.xlim(-1,6)\n",
    "plt.axis('off')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5f51c79-2fc8-4064-81c4-9ad8868114c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot X Samples\n",
    "plt.hist(recons.flatten().numpy(), color='purple', alpha=0.6, bins=500)\n",
    "plt.title(\"Reconstructed X Samples\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36a97d11-4d63-4283-b990-49a148a31803",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## DIVA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "217bd712-1d11-4143-b7c5-d342f48e1d44",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "diva = SDIVA(1, [1], latent_dim=1, w_dim=1, hidden_dim=8)\n",
    "diva_trainer = EpochPyroTrainer(demo_epochs, diva, train_loader, test_loader)\n",
    "diva_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62aee1e4-d7b3-457c-93e7-4b9092af7020",
   "metadata": {},
   "outputs": [],
   "source": [
    "diva_trainer._predictive_setup(s=1)\n",
    "preds = diva_trainer.get_variables('test')\n",
    "z_s = preds['z'][0].cpu()\n",
    "w_s = preds['w'][0].cpu()\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fb86002-03c8-4baa-b8ed-c76a115c25cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(dist.Normal(recons.mean(), recons.std()).log_prob(test_set[:][0]).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d6905bb-2f7f-4f2d-ae7b-259d8aba850b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot X Samples\n",
    "sns.stripplot(recons.flatten().numpy(), orient='y', c=test_set[:][1].numpy(), cmap=cmap, size=5)\n",
    "plt.axis('off')\n",
    "plt.xlim(-6,6)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65c50bc2-c858-48b4-b88e-ebbe3b363946",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.stripplot(z_s, orient='y', c=test_set[:][1].numpy(), cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.xlim(-6,6)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6838318d-eb60-4e36-b865-a83ad0fab777",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.stripplot(w_s, orient='y', c=test_set[:][1].numpy(), cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.xlim(-6,6)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45fa88d2-934a-493a-b3de-fa612bc999c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probe how much label information the latents carry: fit and score a Gaussian\n",
    "# naive Bayes classifier on the concatenated (z, w) test-set latents.\n",
    "combined = torch.concatenate((z_s, w_s), dim=1)\n",
    "\n",
    "# The latents must line up one-to-one with the test split they are scored against.\n",
    "assert combined.shape[0] == test_set[:][0].shape[0]\n",
    "score = GaussianNB().fit(combined, test_set[:][1].squeeze(-1)).score(combined, test_set[:][1].squeeze(-1))\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "# Absolute gap to the accuracy obtained on the raw observations (orig_data_score).\n",
    "print(f'Diff: {np.round(abs(orig_data_score - score), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9ad012a-6cfb-4f36-b5f9-079fff2a950e",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(w_s.flatten().numpy(), color='#9E666F', alpha=0.6, bins=500)\n",
    "plt.xlim(-1,6)\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1725015-8867-4404-8924-3521a4f69252",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(z_s.flatten().numpy(), color='#9E666F', alpha=0.6, bins=500)\n",
    "plt.xlim(-6,6)\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cf70838-b812-4296-b919-f34a2b008e95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot X Samples\n",
    "plt.hist(recons.flatten().numpy(), color='purple', alpha=0.6, bins=500)\n",
    "plt.title(\"Reconstructed X Samples\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d386cfe5-51b3-4d8a-bb47-14b8e29e5404",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## CCVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcb29b46-36a5-4e12-9864-6e94e9533c2e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "ccvae = CCVAE(1, [1], latent_dim=1, w_dim=1, hidden_dim=8)\n",
    "ccvae_trainer = EpochPyroTrainer(demo_epochs, ccvae, train_loader, test_loader)\n",
    "ccvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2448cd3c-ce06-4bae-9161-cbc92092ba06",
   "metadata": {},
   "outputs": [],
   "source": [
    "ccvae_trainer._predictive_setup(s=1)\n",
    "preds = ccvae_trainer.get_variables('test')\n",
    "z_s = preds['z'][0].cpu()\n",
    "w_s = preds['w'][0].cpu()\n",
    "recons = preds['rec'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3926b51-1358-4b3b-a0bd-9d9e8d153e5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(dist.Normal(recons.mean(), recons.std()).log_prob(test_set[:][0]).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a9eb839-a5bf-406b-9539-535765ee94c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot X Samples\n",
    "sns.stripplot(recons.flatten().numpy(), orient='y', c=test_set[:][1].numpy(), cmap=cmap, size=5)\n",
    "plt.axis('off')\n",
    "plt.xlim(-6,6)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97e6af55-b992-4b0a-bac5-ec4afbe744b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.stripplot(z_s, orient='y', c=test_set[:][1].numpy(), cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.xlim(-6,6)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe74550a-617b-4759-8074-4537a91b4f03",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.stripplot(w_s, orient='y', c=test_set[:][1].numpy(), cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.xlim(-6,6)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a619f2d3-01ed-47ee-b2de-2f2bcc4d6f73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probe how much label information the latents carry: fit and score a Gaussian\n",
    "# naive Bayes classifier on the concatenated (z, w) test-set latents.\n",
    "combined = torch.concatenate((z_s, w_s), dim=1)\n",
    "\n",
    "# The latents must line up one-to-one with the test split they are scored against.\n",
    "assert combined.shape[0] == test_set[:][0].shape[0]\n",
    "score = GaussianNB().fit(combined, test_set[:][1].squeeze(-1)).score(combined, test_set[:][1].squeeze(-1))\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "# Absolute gap to the accuracy obtained on the raw observations (orig_data_score).\n",
    "print(f'Diff: {np.round(abs(orig_data_score - score), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "282df93c-2aee-47a3-8e49-53881b7d40ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(w_s.flatten().numpy(), color='#9E666F', alpha=0.6, bins=500)\n",
    "plt.xlim(-1,6)\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a0e121a-6018-4675-b399-b53f889e6b8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(z_s.flatten().numpy(), color='#9E666F', alpha=0.6, bins=500)\n",
    "plt.xlim(-6,6)\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28f9105e-124b-4c03-8fe1-16d3a4581630",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot X Samples\n",
    "plt.hist(recons.flatten().numpy(), color='purple', alpha=0.6, bins=500)\n",
    "plt.title(\"Reconstructed X Samples\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd3d5206-2a02-41d6-bd1b-e638398b1ec4",
   "metadata": {},
   "source": [
    "## DISCoVeR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "287feea5-ecaa-4e6f-ada3-bf26e6a902f5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "dlvae = DLVAE(1, [1], latent_dim=1, w_dim=1, hidden_dim=8, recon_weight=7e-1, recon_weight_z=3e-1, w_kl_weight=2e-1, z_kl_weight=7e-1, adversarial_weight=8e-1)\n",
    "dlvae_trainer = AdversarialEpochPyroTrainer(demo_epochs, 1, 1, dlvae, train_loader, test_loader)\n",
    "dlvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "339dd707-1ff5-4483-a464-a36ee1e24bef",
   "metadata": {},
   "outputs": [],
   "source": [
    "dlvae_trainer._predictive_setup(s=1)\n",
    "preds = dlvae_trainer.get_variables('test')\n",
    "z_s = preds['z'][0].cpu()\n",
    "w_s = preds['w'][0].cpu()\n",
    "recons_z = preds['rec_z'][0, 0].cpu()\n",
    "recons_w = preds['rec_w'][0, 0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6f422df-646d-4527-8ee6-a15e8e00f5a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(dist.Normal(recons_w.mean(), recons_w.std()).log_prob(test_set[:][0]).mean())\n",
    "print(dist.Normal(recons_z.mean(), recons_z.std()).log_prob(test_set[:][0]).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2dab56d-279c-4ef7-9999-fc96a45af061",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot X Samples - rec_w reconstruction, coloured by the test labels\n",
    "sns.stripplot(recons_w.flatten().numpy(), orient='y', c=test_set[:][1].numpy(), cmap=cmap, size=5)\n",
    "plt.axis('off')\n",
    "plt.xlim(-6,6)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e307d84d-bb09-46b3-a9f1-64bc016d8c4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot X Samples - From Z\n",
    "sns.stripplot(recons_z.flatten().numpy(), orient='y', c=test_set[:][1].numpy(), cmap=cmap, size=5)\n",
    "plt.axis('off')\n",
    "plt.xlim(-6,6)\n",
    "\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb95e894-b25d-4312-bfc4-bc782a5068a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.stripplot(z_s, orient='y', c=test_set[:][1].numpy(), cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.xlim(-6,6)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad7556e7-4514-4e97-8c48-d8edea7b7f8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.stripplot(w_s, orient='y', c=test_set[:][1].numpy(), cmap=cmap, alpha=0.6)\n",
    "plt.axis('off')\n",
    "plt.xlim(-6,6)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d999082-8579-42c8-95a6-1874669687c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probe how much label information the latents carry: fit and score a Gaussian\n",
    "# naive Bayes classifier on the concatenated (z, w) test-set latents.\n",
    "combined = torch.concatenate((z_s, w_s), dim=1)\n",
    "\n",
    "# The latents must line up one-to-one with the test split they are scored against.\n",
    "assert combined.shape[0] == test_set[:][0].shape[0]\n",
    "score = GaussianNB().fit(combined, test_set[:][1].squeeze(-1)).score(combined, test_set[:][1].squeeze(-1))\n",
    "print(f'Bayes classifier acc: {np.round(score, 4)}')\n",
    "# NOTE: the baseline is rounded to 2 decimals before differencing, so the 4-decimal\n",
    "# gap printed here is relative to that rounded baseline.\n",
    "print(f'Diff: {np.round(abs(np.round(orig_data_score, 2) - np.round(score, 4)), 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96fc4030-48e9-4d15-837e-3f361524f299",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(z_s.flatten().numpy(), color='#9E666F', alpha=0.6, bins=500)\n",
    "plt.xlim(-6,6)\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23b342ed-8a1d-4a75-ae1c-54994fc1bd07",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(w_s.flatten().numpy(), color='#9E666F', alpha=0.6, bins=500)\n",
    "plt.xlim(-6,6)\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55416d64-5b9f-43c5-b171-180d7db98c79",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(recons_z.flatten().numpy(), color='#9E666F', alpha=0.6, bins=500)\n",
    "plt.xlim(-6,6)\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "881835a0-3ae1-4a24-82a3-8269a649b62b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot X Samples\n",
    "fig, ax = plt.subplots(1,2,figsize=(16,12), sharex=True, sharey=True)\n",
    "\n",
    "ax[0].hist(recons_w.flatten().numpy(), color='purple', alpha=0.6, bins=500)\n",
    "ax[0].set_title(\"Reconstructed X Samples\")\n",
    "\n",
    "ax[1].hist(recons_z.flatten().numpy(), color='purple', alpha=0.6, bins=500)\n",
    "ax[1].set_title(\"Reconstructed X Samples - From Z\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d4b4571-9fa9-4b3b-9bda-bb7601347a6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_lats = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94678d2d-1c75-4b71-b2d8-3a1420d61860",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect model latent samples for a few hand-picked test points.\n",
    "# Every selected point contributes one [x, z, w, 'DISCoVeR'] row per draw to `df_lats`.\n",
    "idxs = [555, 768, 433, 93]\n",
    "n_draws = 1000        # passed as `s` to _predictive_setup; presumably the number of draws per point -- TODO confirm\n",
    "n_test_points = 1000  # leading slice of the test set pushed through the model\n",
    "\n",
    "dlvae_trainer._predictive_setup(s=n_draws)\n",
    "batch_args = dlvae_trainer._send_args_to_device(dlvae_trainer.test_loader.dataset[:n_test_points], dlvae_trainer.device)\n",
    "preds = dlvae_trainer.predictive(*batch_args)\n",
    "z_s = preds['z'].cpu()\n",
    "w_s = preds['w'].cpu()\n",
    "\n",
    "# NOTE(review): `idx` indexes both `test_set` and axis 1 of the predictions, so this assumes\n",
    "# `test_set` and `dlvae_trainer.test_loader.dataset` share ordering and that\n",
    "# max(idxs) < n_test_points -- confirm.\n",
    "for idx in idxs:\n",
    "    # Dedicated name: rebinding `x` here would replace the notebook-level observation array with a scalar.\n",
    "    x_obs = test_set[idx][0].item()\n",
    "\n",
    "    # Keep at most n_draws rows along axis 0 for this test point.\n",
    "    z_draws = z_s[:n_draws, idx, :]\n",
    "    w_draws = w_s[:n_draws, idx, :]\n",
    "\n",
    "    # .item() requires exactly one element per draw (size-1 trailing axis).\n",
    "    df_lats.extend(\n",
    "        [x_obs, z_draw.item(), w_draw.item(), 'DISCoVeR']\n",
    "        for z_draw, w_draw in zip(z_draws, w_draws)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7764d268-68d9-4848-afe4-e8f306848dc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import norm, truncnorm\n",
    "import pandas as pd\n",
    "\n",
    "# Reference samples for the same test points, appended to `df_lats` as 'True Posterior' rows.\n",
    "# z is drawn from N(x/2, 1/2); w is drawn from the same Gaussian truncated at 0,\n",
    "# keeping the side above 0 when y == 1 and the side below 0 otherwise.\n",
    "for idx in idxs:\n",
    "    x, y = (item.item() for item in test_set[idx][:2])\n",
    "\n",
    "    # truncnorm takes its bounds in standardised units: (bound - loc) / scale.\n",
    "    if y == 1:\n",
    "        a, b = (0 - x/2) / np.sqrt(0.5), np.inf\n",
    "    else:\n",
    "        a, b = -np.inf, (0 - x/2) / np.sqrt(0.5)\n",
    "\n",
    "    # Draw z before w so the random stream is consumed in the same order as before.\n",
    "    z = norm.rvs(loc=x/2, scale=np.sqrt(0.5), size=1000)\n",
    "    w = truncnorm.rvs(loc=x/2, scale=np.sqrt(0.5), a=a, b=b, size=1000)\n",
    "\n",
    "    df_lats.extend([x, z_i, w_i, 'True Posterior'] for z_i, w_i in zip(z, w))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "169efc12-9f79-4e34-8132-17d31150cfeb",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from scipy.stats import norm, truncnorm  # imported here so the cell does not rely on another cell's import\n",
    "\n",
    "from pyfonts import load_font\n",
    "\n",
    "# Compare, per selected test point, the Z and W densities of each model in `df_lats`.\n",
    "# Sampled models are summarised by a Gaussian fitted to their draws (mean / std);\n",
    "# 'True Posterior' is drawn analytically: N(x/2, 1/2) for Z and the same Gaussian\n",
    "# truncated at 0 (side chosen by the label y) for W.\n",
    "\n",
    "df = pd.DataFrame(df_lats, columns=['X', 'Z', 'W', 'Model'])\n",
    "linestyles = ['solid', 'solid']\n",
    "alphas = [1, 1]\n",
    "\n",
    "# Two colours, indexed by model order: colorblind[2] first, black second.\n",
    "palette = sns.color_palette('colorblind', n_colors=8)\n",
    "palette[2], palette[-2] = palette[-2], palette[2]\n",
    "palette[-1] = (0, 0, 0)\n",
    "palette = palette[-2:]\n",
    "\n",
    "# Tick/title font, loaded from a remote URL.\n",
    "font = load_font(\n",
    "   font_url='https://github.com/stevenpetryk/computer-modern/blob/main/src/cmunrm.ttf?raw=true'\n",
    ")\n",
    "\n",
    "fig, ax = plt.subplots(2, 4, figsize=(10, 5), sharey=False, sharex=True)\n",
    "\n",
    "xs = np.linspace(-5, 5, 1000)  # evaluation grid, identical for every panel\n",
    "models = df['Model'].unique()\n",
    "# Each test point occupies a (Z, W) pair of columns; two points per row.\n",
    "column_pairs = [(0, 1), (2, 3), (0, 1), (2, 3)]\n",
    "\n",
    "for count, (idx, (z_col, w_col)) in enumerate(zip(idxs, column_pairs)):\n",
    "    # Read observation and label from the test set, exactly as the cells that filled\n",
    "    # `df_lats` do, so the truncation side and the title always match the data\n",
    "    # instead of relying on a hard-coded label list.\n",
    "    x, y = (item.item() for item in test_set[idx][:2])\n",
    "\n",
    "    cur_ax_1, cur_ax_2 = ax[count // 2][z_col], ax[count // 2][w_col]\n",
    "\n",
    "    for i, model in enumerate(models):\n",
    "        if model == 'True Posterior':\n",
    "            mu_z = mu_w = x / 2\n",
    "            sigma_z = sigma_w = np.sqrt(1/2)\n",
    "\n",
    "            # truncnorm bounds are standardised: (0 - loc) / scale.\n",
    "            cut = (0 - mu_w) / sigma_w\n",
    "            a, b = (cut, np.inf) if y == 1 else (-np.inf, cut)\n",
    "            p_w = truncnorm.pdf(xs, loc=mu_w, scale=sigma_w, a=a, b=b)\n",
    "        else:\n",
    "            df_sub = df[(df['X'] == x) & (df['Model'] == model)]\n",
    "            mu_z, sigma_z = df_sub['Z'].mean(), df_sub['Z'].std()\n",
    "            mu_w, sigma_w = df_sub['W'].mean(), df_sub['W'].std()\n",
    "            p_w = norm.pdf(xs, mu_w, sigma_w)\n",
    "\n",
    "        p_z = norm.pdf(xs, mu_z, sigma_z)\n",
    "\n",
    "        cur_ax_1.plot(xs, p_z, c=palette[i], label=f'{model}', linestyle=linestyles[i], alpha=alphas[i])\n",
    "        cur_ax_2.plot(xs, p_w, c=palette[i], linestyle=linestyles[i], alpha=alphas[i])\n",
    "\n",
    "    for panel in (cur_ax_1, cur_ax_2):\n",
    "        panel.set_ylim((0, 1.5))\n",
    "        panel.spines[['top', 'right']].set_visible(False)\n",
    "\n",
    "    cur_ax_2.set_yticks([])\n",
    "\n",
    "    # Restyle the existing tick labels in place. Calling set_*ticklabels() without fixed\n",
    "    # ticks freezes the label text and triggers a matplotlib warning.\n",
    "    for panel in (cur_ax_1, cur_ax_2):\n",
    "        for tick_label in panel.get_xticklabels() + panel.get_yticklabels():\n",
    "            tick_label.set_fontproperties(font)\n",
    "            tick_label.set_fontsize(6)  # after set_fontproperties, which resets the size\n",
    "\n",
    "    cur_ax_1.set_title(f'Z | X = {np.round(x, 3)}', font=font, fontsize=6)\n",
    "    cur_ax_2.set_title(f'W | X = {np.round(x, 3)}, Y = {int(y)}', font=font, fontsize=6)\n",
    "\n",
    "    cur_ax_1.legend()\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
