{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3728b63-75f1-4c03-b20e-98bcf7ee7b2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os, requests\n",
    "sys.path.append('./src/')\n",
    "\n",
    "import scanpy as sc\n",
    "import anndata as ad\n",
    "\n",
    "from metrics import kmeans_ari\n",
    "from NBVAE_variants import ZINBVAE, ZINBCVAE, ZINBCSVAENA, ZINBCSVAE, ZINBHCSVAENA, ZINBHCSVAE, ZINBDLVAE, ZINBDIVA, ZINBCCVAE\n",
    "from VAE_trainers import EpochPyroTrainer, AdversarialEpochPyroTrainer, ThresholdPyroTrainer, AdversarialThresholdPyroTrainer\n",
    "from matplotlib.colors import LinearSegmentedColormap, ListedColormap\n",
    "import seaborn as sns\n",
    "\n",
    "from tqdm import trange, tqdm\n",
    "from umap import UMAP\n",
    "\n",
    "\n",
    "import torch, pyro\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import copy\n",
    "import pyro.optim as opt\n",
    "\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "\n",
    "cmap_trt = LinearSegmentedColormap.from_list(\"cmap\", [\"#42378C\", \"#D9A404\"])\n",
    "cmap_ct = ListedColormap(sns.color_palette('colorblind').as_hex())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d27b996-17ed-46bf-9f23-6dab3f083fe2",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6701337d-7e1a-44f6-921d-2d22a246dc7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, requests\n",
    "\n",
    "# Static data path, update when necessary\n",
    "DATA_PATH = \"https://www.dropbox.com/scl/fi/zxta2nf00p8a9do907rrv/kang_et_al_perturtbations_preprocessed.h5ad?rlkey=bk4pbuily0349borou6rnvjds&st=yvp1zxjd&dl=1\"\n",
    "NAME = \"kang_et_al_perturtbations_preprocessed.h5ad\"\n",
    "\n",
    "# Local destination of the downloaded file\n",
    "save_path = \"./data/\" + NAME\n",
    "\n",
    "# Dropbox checks the agent for some reason\n",
    "headers = {\"user-agent\": \"Wget/1.16 (linux-gnu)\"}\n",
    "\n",
    "if os.path.exists(save_path):\n",
    "    # Already downloaded by an earlier run: do not fetch the file again\n",
    "    print(f\"Object already present at {save_path}\")\n",
    "else:\n",
    "    # Send download request. The `with` block releases the streamed connection\n",
    "    # on every exit path; the timeout stops a stalled server from hanging the cell.\n",
    "    with requests.get(\n",
    "        DATA_PATH, headers=headers, stream=True, allow_redirects=True, timeout=60\n",
    "    ) as response:\n",
    "        ## Unexpected\n",
    "        if response.status_code != 200:\n",
    "            print(f\"HTTP {response.status_code} {response.reason} for {response.url}\")\n",
    "            raise Exception(\"Object not found at URL\")\n",
    "\n",
    "        ## Good\n",
    "        os.makedirs(os.path.dirname(save_path), exist_ok=True)\n",
    "\n",
    "        # Stream into a temporary file first so an interrupted download never\n",
    "        # leaves a truncated file that the existence check above would accept\n",
    "        partial_path = save_path + \".part\"\n",
    "\n",
    "        # Write with progress bar\n",
    "        with tqdm(\n",
    "            total=int(response.headers.get(\"content-length\", 0)),\n",
    "            unit=\"B\",\n",
    "            unit_scale=True,\n",
    "        ) as progress_bar:\n",
    "            with open(partial_path, \"wb\") as f:\n",
    "                for data in response.iter_content(1024):\n",
    "                    progress_bar.update(len(data))\n",
    "                    f.write(data)\n",
    "\n",
    "        # Publish the finished file under its final name\n",
    "        os.replace(partial_path, save_path)\n",
    "\n",
    "    print(f\"Object saved at {save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8601301d-c72a-4b4e-9055-31308b7e24de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is skipped, but left to describe preprocessing.\n",
    "# If you need the raw data for anything, the SeuratData package has it. Check here for more details: https://satijalab.org/seurat/archive/v3.2/immune_alignment.html\n",
    "\n",
    "# if not os.path.exists('./data/kang_et_al_perturtbations_preprocessed.h5ad'):\n",
    "#     anndata = ad.read_h5ad('./data/kang_et_al_perturtbations.h5ad')\n",
    "#     anndata.layers['raw'] = anndata.X.copy()\n",
    "    \n",
    "#     map_dict = {\n",
    "#         'B' : 'B', \n",
    "#         'B Activated' : 'B', \n",
    "#         'CD4 Memory T' : 'CD4 T', \n",
    "#         'CD4 Naive T': 'CD4 T', \n",
    "#         'CD8 T': 'CD8 T', \n",
    "#         'CD14 Mono': 'CD14 Mono',\n",
    "#         'CD16 Mono': 'CD16 Mono', \n",
    "#         'DC': 'DC', \n",
    "#         'Eryth': 'Eryth', \n",
    "#         'Mk': 'Mk', \n",
    "#         'NK': 'NK', \n",
    "#         'T activated': 'T', \n",
    "#         'pDC': 'DC',\n",
    "#     }\n",
    "    \n",
    "#     anndata.obs['seurat_annotations'] = anndata.obs['seurat_annotations'].astype(object).apply(lambda x: map_dict[x]).astype('category')    \n",
    "    \n",
    "#     sc.pp.filter_cells(anndata, min_genes=200)\n",
    "#     sc.pp.filter_genes(anndata, min_cells=3)\n",
    "#     sc.pp.normalize_total(anndata)\n",
    "#     sc.pp.log1p(anndata)\n",
    "    \n",
    "#     sc.pp.highly_variable_genes(anndata, n_top_genes=2000)\n",
    "#     anndata = anndata[:, anndata.var['highly_variable']]\n",
    "    \n",
    "#     anndata.write_h5ad('./data/kang_et_al_perturtbations_preprocessed.h5ad')\n",
    "    \n",
    "# else:\n",
    "\n",
    "anndata = ad.read_h5ad('./data/kang_et_al_perturtbations_preprocessed.h5ad')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a5637af-9d6c-4e3e-97ef-b81f9dec02b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.utils.data as utils\n",
    "\n",
    "# Re-seed numpy, torch and pyro before building the datasets and loaders\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "batch_size = 128\n",
    "\n",
    "# Category codes of the two obs columns, later reshaped to (N, 1) float columns\n",
    "stim_codes = anndata.obs['stim'].cat.codes.to_numpy()\n",
    "annotation_codes = anndata.obs['seurat_annotations'].cat.codes.to_numpy()\n",
    "\n",
    "# Model input is the 'raw' layer of the AnnData object\n",
    "x = torch.FloatTensor(anndata.layers['raw'].copy())\n",
    "y = torch.FloatTensor(stim_codes.reshape(-1, 1).copy())\n",
    "y_info = torch.FloatTensor(annotation_codes.reshape(-1, 1).copy())\n",
    "\n",
    "dataset = utils.TensorDataset(x, y, y_info)\n",
    "\n",
    "# No hold-out split: the train and test sets both wrap the full dataset\n",
    "train_set = utils.TensorDataset(*dataset[:])\n",
    "test_set = utils.TensorDataset(*dataset[:])\n",
    "\n",
    "train_loader = utils.DataLoader(train_set, shuffle=True, batch_size=batch_size)\n",
    "# NOTE(review): the test loader is shuffled as well, while later cells colour model\n",
    "# outputs by test_set[:] order -- confirm the trainer does not depend on loader order\n",
    "test_loader = utils.DataLoader(test_set, shuffle=True, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18839a86-b340-452e-91af-041fc4d7a670",
   "metadata": {},
   "outputs": [],
   "source": [
    "reducer = UMAP(random_state=0, metric='correlation')\n",
    "umap_data = reducer.fit_transform(np.log1p(np.array(test_set[:][0])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad45a06d-061d-4b09-aa42-473da03bd509",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyfonts import load_font\n",
    "\n",
    "# load font\n",
    "font = load_font(\n",
    "   font_url=\"https://github.com/stevenpetryk/computer-modern/blob/main/src/cmunrm.ttf?raw=true\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6894445-0ce1-4edb-8574-a3d3b5e6514f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,5))\n",
    "\n",
    "scatter = plt.scatter(umap_data[:, 0], umap_data[:, 1], c=test_set[:][2], cmap=cmap_ct, s=5, alpha=1)\n",
    "\n",
    "\n",
    "plt.gca().set_aspect('equal')\n",
    "plt.yticks([])\n",
    "plt.xticks([])\n",
    "plt.gca().spines[['right', 'top']].set_visible(False)\n",
    "\n",
    "\n",
    "plt.ylabel('UMAP 2', fontsize=24, font=font)\n",
    "plt.xlabel('UMAP 1', fontsize=24, font=font)\n",
    "\n",
    "#leg = plt.legend(handles=scatter.legend_elements()[0], labels=list(anndata.obs['seurat_annotations'].cat.categories), fontsize=18)\n",
    "\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d40c33b-f9ee-4548-9711-028d313c73ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,5))\n",
    "\n",
    "\n",
    "scatter = plt.scatter(umap_data[:, 0], umap_data[:, 1], c=test_set[:][1], cmap=cmap_trt, s=1)\n",
    "\n",
    "plt.gca().set_aspect('equal')\n",
    "plt.yticks([])\n",
    "plt.xticks([])\n",
    "plt.gca().spines[['right', 'top']].set_visible(False)\n",
    "\n",
    "plt.ylabel('UMAP 2', font=font, fontsize=24)\n",
    "plt.xlabel('UMAP 1', font=font, fontsize=24)\n",
    "\n",
    "#plt.legend(handles=scatter.legend_elements()[0], labels=list(anndata.obs['stim'].cat.categories), fontsize=18)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6f26018-a22f-4938-9272-edf0dff011e7",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## CSVAE - No Adv."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf47b078-6426-414b-bf6f-64b2c5abe32e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "csvaena = ZINBCSVAENA(2000, [1], latent_dim=10, w_dim=2, num_layers=1, hidden_dim=128, recon_weight=1, z_kl_weight=1e-4, w_kl_weight=1)\n",
    "csvaena_trainer = ThresholdPyroTrainer(0, 50, csvaena, train_loader, test_loader, opt.AdamW({\"lr\": 1e-3}))\n",
    "csvaena_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ad6915c-97f3-49f2-80c1-e961b768f849",
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = csvaena_trainer.get_variables('test')\n",
    "z_s = preds['z'][0].cpu()\n",
    "w_s = preds['w'][0].cpu()\n",
    "recons = preds['rec'][0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e106bd9-7264-4ab7-88df-1e50e1e67256",
   "metadata": {},
   "outputs": [],
   "source": [
    "trace = csvaena_trainer.get_trace('test')\n",
    "\n",
    "# Move the inputs to the device of the traced 'rec' site instead of\n",
    "# assuming that a CUDA device is available\n",
    "rec_site = trace.nodes['rec']\n",
    "x_obs = test_set[:][0].to(rec_site['value'].device)\n",
    "\n",
    "# Negated mean log_prob of the test inputs under the 'rec' site's distribution\n",
    "print(-1 * rec_site['fn'].log_prob(x_obs).mean().item())\n",
    "\n",
    "# Per-column mean of log1p values: data (x-axis) vs. reconstruction (y-axis)\n",
    "plt.figure()\n",
    "plt.scatter(test_set[:][0].log1p().mean(dim=0), recons.log1p().mean(dim=0))\n",
    "plt.gca().axis('square')\n",
    "\n",
    "# Per-column variance of log1p values: data (x-axis) vs. reconstruction (y-axis)\n",
    "plt.figure()\n",
    "plt.scatter(test_set[:][0].log1p().var(dim=0), recons.log1p().var(dim=0))\n",
    "plt.gca().axis('square')\n",
    "\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e1af361-85d8-4ff8-8ae8-d5c10abf8bfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the latent z taken directly from `preds`, so re-running this cell does\n",
    "# not feed an already-embedded z_s back into UMAP; the fixed seed (same as the\n",
    "# input-data UMAP) keeps the layout reproducible\n",
    "reducer = UMAP(random_state=0)\n",
    "z_s = reducer.fit_transform(preds['z'][0].cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7f06347-e7df-4769-bab5-48f00d01044d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_set[:][2].numpy(), cmap=cmap_ct, s=1)\n",
    "plt.axis('off')\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ae26630-d7b6-46d9-ade5-caa79fbe5bc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_set[:][1].numpy(), cmap=cmap_trt, s=1)\n",
    "plt.axis('off')\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e15e5f0f-14e5-4f71-94ce-67785184e23f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_set[:][2].numpy(), cmap=cmap_ct, s=1)\n",
    "plt.axis('off')\n",
    "plt.gca().set_aspect('equal')\n",
    "plt.ylim(-2,7)\n",
    "plt.xlim(-2,7)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e15a93e-0da3-4f6a-9237-8e2390366102",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_set[:][1].numpy(), cmap=cmap_trt, s=1)\n",
    "plt.axis('off')\n",
    "plt.gca().set_aspect('equal')\n",
    "plt.ylim(-2,7)\n",
    "plt.xlim(-2,7)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6353a805-1e29-4a23-b93d-02af6fd528d4",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## CSVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bbe3526-cb39-4e18-bcaa-aa9a8a3c5168",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "csvae = ZINBCSVAE(2000, [1], latent_dim=10, w_dim=2, num_layers=1, hidden_dim=128, recon_weight=1, z_kl_weight=1e-4, w_kl_weight=1, adversarial_weight=1e2)\n",
    "csvae_trainer = AdversarialThresholdPyroTrainer(0, 50, 1, 1, csvae, train_loader, test_loader, opt.AdamW({\"lr\": 1e-3}))\n",
    "csvae_trainer.train()\n",
    "csvae_trainer.save('params/Kang/csvae_kang')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d11a1a53-f084-4bfb-81b9-b4eba74a603c",
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = csvae_trainer.get_variables('test')\n",
    "z_s = preds['z'][0].cpu()\n",
    "w_s = preds['w'][0].cpu()\n",
    "recons = preds['rec'][0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bc625fe-5131-4e4d-97ba-dbe5c65ed37c",
   "metadata": {},
   "outputs": [],
   "source": [
    "trace = csvae_trainer.get_trace('test')\n",
    "\n",
    "# Move the inputs to the device of the traced 'rec' site instead of\n",
    "# assuming that a CUDA device is available\n",
    "rec_site = trace.nodes['rec']\n",
    "x_obs = test_set[:][0].to(rec_site['value'].device)\n",
    "\n",
    "# Negated mean log_prob of the test inputs under the 'rec' site's distribution\n",
    "print(-1 * rec_site['fn'].log_prob(x_obs).mean().item())\n",
    "\n",
    "# Per-column mean of log1p values: data (x-axis) vs. reconstruction (y-axis)\n",
    "plt.figure()\n",
    "plt.scatter(test_set[:][0].log1p().mean(dim=0), recons.log1p().mean(dim=0))\n",
    "plt.gca().axis('square')\n",
    "\n",
    "# Per-column variance of log1p values: data (x-axis) vs. reconstruction (y-axis)\n",
    "plt.figure()\n",
    "plt.scatter(test_set[:][0].log1p().var(dim=0), recons.log1p().var(dim=0))\n",
    "plt.gca().axis('square')\n",
    "\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de82c4c0-45bd-4540-9c00-33185dcd1968",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the latent z taken directly from `preds`, so re-running this cell does\n",
    "# not feed an already-embedded z_s back into UMAP; the fixed seed (same as the\n",
    "# input-data UMAP) keeps the layout reproducible\n",
    "reducer = UMAP(random_state=0)\n",
    "z_s = reducer.fit_transform(preds['z'][0].cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00c34de5-68a3-4e08-8224-e618b1196eed",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_set[:][2].numpy(), cmap=cmap_ct, s=1)\n",
    "plt.axis('off')\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1635c533-53d1-446e-a9a5-8105880e1a43",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_set[:][1].numpy(), cmap=cmap_trt, s=1)\n",
    "plt.axis('off')\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "785f7e7b-c507-46b2-9fd3-3ac5fbf67b43",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_set[:][2].numpy(), cmap=cmap_ct, s=1)\n",
    "plt.axis('off')\n",
    "plt.ylim(-2,7)\n",
    "plt.xlim(-2,7)\n",
    "\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aab6362-2eeb-4307-a08b-b9be80542d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_set[:][1].numpy(), cmap=cmap_trt, s=1)\n",
    "plt.axis('off')\n",
    "plt.ylim(-2,7)\n",
    "plt.xlim(-2,7)\n",
    "\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71d4c2a6-6aee-488e-b1cb-fe518fabf2aa",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## HCSVAE - No Adv."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49a089ad-59ac-47c4-b98e-566fa24cd2df",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "hcsvaena = ZINBHCSVAENA(2000, [1], latent_dim=10, w_dim=2, num_layers=1, hidden_dim=128, recon_weight=1, z_kl_weight=1e-4, w_kl_weight=1)\n",
    "hcsvaena_trainer = ThresholdPyroTrainer(0, 50, hcsvaena, train_loader, test_loader, opt.AdamW({\"lr\": 1e-3}))\n",
    "hcsvaena_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "199eff9d-dac6-4f75-a43f-50b40caae573",
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = hcsvaena_trainer.get_variables('test')\n",
    "z_s = preds['z'][0].cpu()\n",
    "w_s = preds['w'][0].cpu()\n",
    "recons = preds['rec'][0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7231cf37-846d-4b79-a123-a207c2a7d41b",
   "metadata": {},
   "outputs": [],
   "source": [
    "trace = hcsvaena_trainer.get_trace('test')\n",
    "\n",
    "# Move the inputs to the device of the traced 'rec' site instead of\n",
    "# assuming that a CUDA device is available\n",
    "rec_site = trace.nodes['rec']\n",
    "x_obs = test_set[:][0].to(rec_site['value'].device)\n",
    "\n",
    "# Negated mean log_prob of the test inputs under the 'rec' site's distribution\n",
    "print(-1 * rec_site['fn'].log_prob(x_obs).mean().item())\n",
    "\n",
    "# Per-column mean of log1p values: data (x-axis) vs. reconstruction (y-axis)\n",
    "plt.figure()\n",
    "plt.scatter(test_set[:][0].log1p().mean(dim=0), recons.log1p().mean(dim=0))\n",
    "plt.gca().axis('square')\n",
    "\n",
    "# Per-column variance of log1p values: data (x-axis) vs. reconstruction (y-axis)\n",
    "plt.figure()\n",
    "plt.scatter(test_set[:][0].log1p().var(dim=0), recons.log1p().var(dim=0))\n",
    "plt.gca().axis('square')\n",
    "\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f603bb1c-48bc-40fc-98c0-95e28dec97d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the latent z taken directly from `preds`, so re-running this cell does\n",
    "# not feed an already-embedded z_s back into UMAP; the fixed seed (same as the\n",
    "# input-data UMAP) keeps the layout reproducible\n",
    "reducer = UMAP(random_state=0)\n",
    "z_s = reducer.fit_transform(preds['z'][0].cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd431f11-0262-40ae-94ea-f8594c39b3c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_set[:][2].numpy(), cmap=cmap_ct, s=1)\n",
    "plt.axis('off')\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faf9e9a8-2713-4966-ad46-32a69e1e74bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_set[:][1].numpy(), cmap=cmap_trt, s=1)\n",
    "plt.axis('off')\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17652034-b874-4212-9022-c32d42feeebd",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_set[:][2].numpy(), cmap=cmap_ct, s=1)\n",
    "plt.axis('off')\n",
    "plt.ylim(-2,7)\n",
    "plt.xlim(-2,7)\n",
    "\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d0b4f39-cb3d-44e3-9ef2-17e76c69ffd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_set[:][1].numpy(), cmap=cmap_trt, s=1)\n",
    "plt.axis('off')\n",
    "plt.ylim(-2,7)\n",
    "plt.xlim(-2,7)\n",
    "\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2dd88170-a0d9-457b-92a4-f37982b92c70",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## HCSVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ca1e85e-0f17-4f82-86ba-7e2dfb2b3515",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Re-seed numpy, torch and pyro so model initialisation and training start from a fixed state\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "hcsvae = ZINBHCSVAE(2000, [1], latent_dim=10, w_dim=2, num_layers=1, hidden_dim=128, recon_weight=1, z_kl_weight=1e-4, w_kl_weight=1, adversarial_weight=1e2)\n",
    "# Positional trainer arguments mirror the CSVAE cell: (0, 50, 1, 1, model, loaders, optimizer)\n",
    "hcsvae_trainer = AdversarialThresholdPyroTrainer(0, 50, 1, 1, hcsvae, train_loader, test_loader, opt.AdamW({\"lr\": 1e-3}))\n",
    "hcsvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8741067c-3d9f-4b9a-9aec-34abeabfc33d",
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = hcsvae_trainer.get_variables('test')\n",
    "z_s = preds['z'][0].cpu()\n",
    "w_s = preds['w'][0].cpu()\n",
    "rho_s = preds['rho'][0].cpu()\n",
    "recons = preds['rec'][0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d710d4c-9c6a-4ffd-b16b-d6df39cf12ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "trace = hcsvae_trainer.get_trace('test')\n",
    "\n",
    "# Move the inputs to the device of the traced 'rec' site instead of\n",
    "# assuming that a CUDA device is available\n",
    "rec_site = trace.nodes['rec']\n",
    "x_obs = test_set[:][0].to(rec_site['value'].device)\n",
    "\n",
    "# Negated mean log_prob of the test inputs under the 'rec' site's distribution\n",
    "print(-1 * rec_site['fn'].log_prob(x_obs).mean().item())\n",
    "\n",
    "# Per-column mean of log1p values: data (x-axis) vs. reconstruction (y-axis)\n",
    "plt.figure()\n",
    "plt.scatter(test_set[:][0].log1p().mean(dim=0), recons.log1p().mean(dim=0))\n",
    "plt.gca().axis('square')\n",
    "\n",
    "# Per-column variance of log1p values: data (x-axis) vs. reconstruction (y-axis)\n",
    "plt.figure()\n",
    "plt.scatter(test_set[:][0].log1p().var(dim=0), recons.log1p().var(dim=0))\n",
    "plt.gca().axis('square')\n",
    "\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38e1c7d5-af9b-4c77-aa4a-3166ee16b25e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the latent z taken directly from `preds`, so re-running this cell does\n",
    "# not feed an already-embedded z_s back into UMAP; the fixed seed (same as the\n",
    "# input-data UMAP) keeps the layout reproducible\n",
    "reducer = UMAP(random_state=0)\n",
    "z_s = reducer.fit_transform(preds['z'][0].cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8926430a-b874-426c-8aa5-9e5d069e866f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_set[:][2].numpy(), cmap=cmap_ct, s=1)\n",
    "plt.axis('off')\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2148bbc-5255-4a42-923f-ea25eca4689b",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_set[:][1].numpy(), cmap=cmap_trt, s=1)\n",
    "plt.axis('off')\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eb58d88-6326-4f7f-899c-e0aab5b98c00",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_set[:][2].numpy(), cmap=cmap_ct, s=1)\n",
    "plt.axis('off')\n",
    "plt.ylim(-2,7)\n",
    "plt.xlim(-2,7)\n",
    "\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4f2d837-a2dc-4421-92a7-92538ccc6f54",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_set[:][1].numpy(), cmap=cmap_trt, s=1)\n",
    "plt.axis('off')\n",
    "plt.ylim(-2,7)\n",
    "plt.xlim(-2,7)\n",
    "\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a470272-ff76-4cd1-b1c9-59d1d4fbb367",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## DIVA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cfe8dc4-c3a4-4f56-829c-17d949bc3ae8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "diva = ZINBDIVA(2000, [1], latent_dim=10, w_dim=2, num_layers=1, hidden_dim=128, recon_weight=1, kl_weight=1e-4)\n",
    "diva_trainer = ThresholdPyroTrainer(0, 50, diva, train_loader, test_loader, opt.AdamW({\"lr\": 1e-3}))\n",
    "diva_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d41e1da6-9e5b-4b2f-9faa-f6a5d3864ad0",
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = diva_trainer.get_variables('test')\n",
    "z_s = preds['z'][0].cpu()\n",
    "w_s = preds['w'][0].cpu()\n",
    "recons = preds['rec'][0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef068969-d557-487b-8c8b-0cb6c827190a",
   "metadata": {},
   "outputs": [],
   "source": [
    "trace = diva_trainer.get_trace('test')\n",
    "\n",
    "# Move the inputs to the device of the traced 'rec' site instead of\n",
    "# assuming that a CUDA device is available\n",
    "rec_site = trace.nodes['rec']\n",
    "x_obs = test_set[:][0].to(rec_site['value'].device)\n",
    "\n",
    "# Negated mean log_prob of the test inputs under the 'rec' site's distribution\n",
    "print(-1 * rec_site['fn'].log_prob(x_obs).mean().item())\n",
    "\n",
    "# Per-column mean of log1p values: data (x-axis) vs. reconstruction (y-axis)\n",
    "plt.figure()\n",
    "plt.scatter(test_set[:][0].log1p().mean(dim=0), recons.log1p().mean(dim=0))\n",
    "plt.gca().axis('square')\n",
    "\n",
    "# Per-column variance of log1p values: data (x-axis) vs. reconstruction (y-axis)\n",
    "plt.figure()\n",
    "plt.scatter(test_set[:][0].log1p().var(dim=0), recons.log1p().var(dim=0))\n",
    "plt.gca().axis('square')\n",
    "\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ebb32b2-2df4-4639-ac3f-51c9a9aab056",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the latent z taken directly from `preds`, so re-running this cell does\n",
    "# not feed an already-embedded z_s back into UMAP; the fixed seed (same as the\n",
    "# input-data UMAP) keeps the layout reproducible\n",
    "reducer = UMAP(random_state=0)\n",
    "z_s = reducer.fit_transform(preds['z'][0].cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "765ad3b3-720c-4bf9-bcce-a0ad714796fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_set[:][2].numpy(), cmap=cmap_ct, s=1)\n",
    "plt.axis('off')\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "388e90ee-9e7e-4f4b-8044-b02b0a2c20d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_set[:][1].numpy(), cmap=cmap_trt, s=1)\n",
    "plt.axis('off')\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ab1ee95-da3b-4862-869f-5766bc30f7b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_set[:][2].numpy(), cmap=cmap_ct, s=1)\n",
    "plt.axis('off')\n",
    "plt.ylim(-7,7)\n",
    "plt.xlim(-7,7)\n",
    "\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f190d035-c11f-4606-8513-713ebe9721c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_set[:][1].numpy(), cmap=cmap_trt, s=1)\n",
    "plt.axis('off')\n",
    "plt.ylim(-7,7)\n",
    "plt.xlim(-7,7)\n",
    "\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd67328a-045c-4f7b-a2d1-87551b870c3d",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## CCVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f588a4be-0abf-4268-890c-0972c0150ec9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "ccvae = ZINBCCVAE(2000, [1], latent_dim=10, w_dim=2, num_layers=1, hidden_dim=128, recon_weight=1, kl_weight=1e-4)\n",
    "ccvae_trainer = ThresholdPyroTrainer(0, 50, ccvae, train_loader, test_loader, opt.AdamW({\"lr\": 1e-3}))\n",
    "ccvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d655d3d-3e9b-463f-bbbb-0db7eb5a5dad",
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = ccvae_trainer.get_variables('test')\n",
    "z_s = preds['z'][0].cpu()\n",
    "w_s = preds['w'][0].cpu()\n",
    "recons = preds['rec'][0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "145c545f-2112-49c4-9251-a9b4be81366a",
   "metadata": {},
   "outputs": [],
   "source": [
    "trace = ccvae_trainer.get_trace('test')\n",
    "\n",
    "# Move the inputs to the device of the traced 'rec' site instead of\n",
    "# assuming that a CUDA device is available\n",
    "rec_site = trace.nodes['rec']\n",
    "x_obs = test_set[:][0].to(rec_site['value'].device)\n",
    "\n",
    "# Negated mean log_prob of the test inputs under the 'rec' site's distribution\n",
    "print(-1 * rec_site['fn'].log_prob(x_obs).mean().item())\n",
    "\n",
    "# Per-column mean of log1p values: data (x-axis) vs. reconstruction (y-axis)\n",
    "plt.figure()\n",
    "plt.scatter(test_set[:][0].log1p().mean(dim=0), recons.log1p().mean(dim=0))\n",
    "plt.gca().axis('square')\n",
    "\n",
    "# Per-column variance of log1p values: data (x-axis) vs. reconstruction (y-axis)\n",
    "plt.figure()\n",
    "plt.scatter(test_set[:][0].log1p().var(dim=0), recons.log1p().var(dim=0))\n",
    "plt.gca().axis('square')\n",
    "\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7ea2b19-65c0-4ffd-bb25-e1600203e026",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the latent z taken directly from `preds`, so re-running this cell does\n",
    "# not feed an already-embedded z_s back into UMAP; the fixed seed (same as the\n",
    "# input-data UMAP) keeps the layout reproducible\n",
    "reducer = UMAP(random_state=0)\n",
    "z_s = reducer.fit_transform(preds['z'][0].cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3906633e-b8f3-4742-859c-d9ec06e73581",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_set[:][2].numpy(), cmap=cmap_ct, s=1)\n",
    "plt.axis('off')\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aa31ec8-b086-4ce2-8104-b9d9b96e41a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_set[:][1].numpy(), cmap=cmap_trt, s=1)\n",
    "plt.axis('off')\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8e5d509-494a-4fdb-81e7-1c79718e535a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_set[:][2].numpy(), cmap=cmap_ct, s=1)\n",
    "plt.axis('off')\n",
    "plt.ylim(-7,7)\n",
    "plt.xlim(-7,7)\n",
    "\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "243a4c7a-4a21-4fca-8afa-9e9fbd9930c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_set[:][1].numpy(), cmap=cmap_trt, s=1)\n",
    "plt.axis('off')\n",
    "plt.ylim(-7,7)\n",
    "plt.xlim(-7,7)\n",
    "\n",
    "plt.gca().set_aspect(1/plt.gca().get_data_ratio())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43130ba1-1257-45f6-97be-6ec4c46b7cd9",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## DISCoVeR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23c27591-36fc-4367-8fa7-b01b88ecf75a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "pyro.util.set_rng_seed(42)\n",
    "\n",
    "dlvae = ZINBDLVAE(2000, [1], latent_dim=10, w_dim=10, num_layers=0, hidden_dim=128, recon_weight=9e-1, recon_weight_z=1e-1, w_kl_weight=1e-4, z_kl_weight=1e-4, adversarial_weight=1e2, learnable_prior=False)\n",
    "dlvae_trainer = AdversarialThresholdPyroTrainer(0, 50, 1, 1, dlvae, train_loader, test_loader, opt.AdamW({\"lr\": 1e-3}))\n",
    "dlvae_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8676fb38-3e33-4de6-983f-7610975ac38a",
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = dlvae_trainer.get_variables('test')\n",
    "z_s = preds['z'][0].cpu()\n",
    "w_s = preds['w'][0].cpu()\n",
    "recons_w = preds['rec_w'][0].cpu()\n",
    "recons_z = preds['rec_z'][0].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25eb7a42-8de1-4c4e-8735-7d8fba068a43",
   "metadata": {},
   "outputs": [],
   "source": [
    "trace = dlvae_trainer.get_trace('test')\n",
    "\n",
    "# Negated mean log_prob of the test inputs under each reconstruction site.\n",
    "# The inputs are moved to the device of the traced site instead of assuming\n",
    "# that a CUDA device is available.\n",
    "for site_name in ('rec_w', 'rec_z'):\n",
    "    rec_site = trace.nodes[site_name]\n",
    "    x_obs = test_set[:][0].to(rec_site['value'].device)\n",
    "    print(-1 * rec_site['fn'].log_prob(x_obs).mean().item())\n",
    "\n",
    "# Per-column mean of log1p values: data (x-axis) vs. both reconstructions (y-axis)\n",
    "plt.figure()\n",
    "plt.scatter(test_set[:][0].log1p().mean(dim=0), recons_w.log1p().mean(dim=0))\n",
    "plt.scatter(test_set[:][0].log1p().mean(dim=0), recons_z.log1p().mean(dim=0))\n",
    "\n",
    "plt.gca().axis('square')\n",
    "\n",
    "# Per-column variance of log1p values: data (x-axis) vs. both reconstructions (y-axis)\n",
    "plt.figure()\n",
    "plt.scatter(test_set[:][0].log1p().var(dim=0), recons_w.log1p().var(dim=0))\n",
    "plt.scatter(test_set[:][0].log1p().var(dim=0), recons_z.log1p().var(dim=0))\n",
    "plt.gca().axis('square')\n",
    "\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a0a707-4da7-43b5-9047-672284ac7add",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The fixed seed (same as the input-data UMAP) keeps the layouts reproducible.\n",
    "# z and w are embedded directly from `preds`, so re-running this cell does not\n",
    "# feed the already-embedded z_s / w_s back into UMAP.\n",
    "reducer = UMAP(random_state=0)\n",
    "umap_recon_zs = reducer.fit_transform(torch.log1p(recons_z))\n",
    "z_s = reducer.fit_transform(preds['z'][0].cpu())\n",
    "w_s = reducer.fit_transform(preds['w'][0].cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "247d633b-79aa-4dfa-bab7-4349d7a74a6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,5))\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_set[:][2].numpy(), cmap=cmap_ct, s=1)\n",
    "plt.gca().set_aspect('equal')\n",
    "plt.yticks([])\n",
    "plt.xticks([])\n",
    "plt.gca().spines[['right', 'top']].set_visible(False)\n",
    "\n",
    "plt.ylabel('UMAP 2', font=font, fontsize=24)\n",
    "plt.xlabel('UMAP 1', font=font, fontsize=24)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0883fa3-29a1-4c90-9872-a55744112d4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,5))\n",
    "plt.scatter(z_s[:, 0], z_s[:, 1], c=test_set[:][1].numpy(), cmap=cmap_trt, s=1)\n",
    "plt.gca().set_aspect('equal')\n",
    "plt.yticks([])\n",
    "plt.xticks([])\n",
    "plt.gca().spines[['right', 'top']].set_visible(False)\n",
    "\n",
    "plt.ylabel('UMAP 2', font=font, fontsize=24)\n",
    "plt.xlabel('UMAP 1', font=font, fontsize=24)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4bd3578-67b2-4d35-bf6a-eba04aceb44a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,5))\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_set[:][2].numpy(), cmap=cmap_ct, s=1)\n",
    "plt.gca().set_aspect('equal')\n",
    "plt.yticks([])\n",
    "plt.xticks([])\n",
    "plt.gca().spines[['right', 'top']].set_visible(False)\n",
    "\n",
    "plt.ylabel('UMAP 2', font=font, fontsize=24)\n",
    "plt.xlabel('UMAP 1', font=font, fontsize=24)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbe85b51-2dcf-475e-aac3-a5e2ad15b1bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,5))\n",
    "plt.scatter(w_s[:, 0], w_s[:, 1], c=test_set[:][1].numpy(), cmap=cmap_trt, s=1)\n",
    "plt.gca().set_aspect('equal')\n",
    "plt.yticks([])\n",
    "plt.xticks([])\n",
    "plt.gca().spines[['right', 'top']].set_visible(False)\n",
    "\n",
    "plt.ylabel('UMAP 2', font=font, fontsize=24)\n",
    "plt.xlabel('UMAP 1', font=font, fontsize=24)\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
