{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import statements\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import yaml\n",
    "\n",
    "from models import beta_vae\n",
    "from experiment import VAEXperiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load CelebA\n",
    "\n",
    "root = \"\" # path to CelebA dataset\n",
    "batch_size = 16\n",
    "\n",
    "# Resize every image to 64x64, then convert it to a tensor\n",
    "downsize = torchvision.transforms.Resize((64, 64))\n",
    "to_tensor = torchvision.transforms.ToTensor()\n",
    "composed_transform = torchvision.transforms.Compose([downsize, to_tensor])\n",
    "\n",
    "trainset = torchvision.datasets.CelebA(\n",
    "    root=root,\n",
    "    split='train',\n",
    "    download=True,\n",
    "    transform=composed_transform,\n",
    ")\n",
    "\n",
    "# Subset restricted to indices 0..1999 of the training split.\n",
    "# NOTE(review): the loader below is built on the full `trainset`; this subset is not read by it -- confirm which was intended.\n",
    "trainset_abridged = torch.utils.data.Subset(trainset, range(2000))\n",
    "\n",
    "# Unshuffled loader: batches come out in dataset order\n",
    "trainloader = torch.utils.data.DataLoader(\n",
    "    trainset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=False,\n",
    "    num_workers=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize beta_vae\n",
    "#\n",
    "# Builds the BetaVAE network, wraps it in VAEXperiment and restores the\n",
    "# trained weights from a checkpoint. After this cell `model` is the restored\n",
    "# VAEXperiment (the underlying network is reached through `model.model`).\n",
    "\n",
    "in_channels = 3\n",
    "latent_dim = 32\n",
    "loss_type = 'H'\n",
    "beta = 10.0\n",
    "\n",
    "model = beta_vae.BetaVAE(in_channels=in_channels, latent_dim=latent_dim, loss_type=loss_type, beta=beta).to(device)\n",
    "\n",
    "# load params\n",
    "\n",
    "param_path = \"/PyTorch-VAE/configs/bhvae.yaml\"\n",
    "with open(param_path, 'r') as f:\n",
    "    # NOTE(review): yaml.safe_load is preferable if the config only holds plain scalars/lists/dicts -- confirm\n",
    "    config = yaml.load(f, Loader=yaml.FullLoader)\n",
    "\n",
    "params = config['exp_params']\n",
    "\n",
    "vae = VAEXperiment(vae_model=model, params=params)\n",
    "\n",
    "# Load checkpoint\n",
    "\n",
    "checkpoint_path = \"\" # fill in path to checkpoint\n",
    "\n",
    "# Load weights from state_dict\n",
    "#\n",
    "# load_from_checkpoint is a classmethod (assuming VAEXperiment is a Lightning\n",
    "# module -- TODO confirm): call it on the class, since recent Lightning releases\n",
    "# reject calls made on an instance. map_location lets a checkpoint saved on a\n",
    "# GPU be restored on whatever `device` this notebook is running on.\n",
    "\n",
    "model = VAEXperiment.load_from_checkpoint(checkpoint_path, map_location=device, vae_model=model, params=params)\n",
    "\n",
    "# Keep the restored module on the same device as the inputs used below, and\n",
    "# switch to eval mode so that train-time-only layer behavior (e.g. BatchNorm\n",
    "# batch statistics or Dropout, if the network has them) is disabled for inference.\n",
    "\n",
    "model = model.to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the first batch from the training loader and keep only the images\n",
    "# (element 0 of the batch), moved to `device`.\n",
    "# The loader is created with shuffle=False, so this is the leading batch of the\n",
    "# dataset rather than a random draw.\n",
    "\n",
    "X = next(iter(trainloader))[0].to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show(img):\n",
    "    \"\"\"Display an image tensor in a new 20x5 inch figure without axis ticks.\n",
    "\n",
    "    The new figure becomes matplotlib's current figure, so a following\n",
    "    plt.savefig(...) call writes this image.\n",
    "\n",
    "    Args:\n",
    "        img: 3-D tensor whose axes are reordered with (1, 2, 0) before plotting\n",
    "            -- presumably (channels, height, width), as produced by make_grid.\n",
    "            It may live on any device and may be attached to an autograd graph.\n",
    "    \"\"\"\n",
    "    # Tensor.numpy() only works for CPU tensors that do not require grad\n",
    "    npimg = img.detach().cpu().numpy()\n",
    "    fig, ax = plt.subplots(figsize=(20,5))\n",
    "    # no ticks\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    ax.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the first 8 elements of X\n",
    "\n",
    "X_viz = X[:8].cpu()\n",
    "\n",
    "# Arrange the images in a grid (up to 10 per row) and display it\n",
    "\n",
    "grid = torchvision.utils.make_grid(\n",
    "    tensor=X_viz.reshape(-1,3,64,64),\n",
    "    nrow=10,\n",
    ")\n",
    "show(grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize reconstructions\n",
    "\n",
    "# Run the first 8 images through the model; element 0 of the returned value\n",
    "# is used as the reconstruction, detached and copied to the CPU for plotting\n",
    "X_rec = (\n",
    "    model(X[:8])[0]\n",
    "    .detach()\n",
    "    .cpu()\n",
    ")\n",
    "\n",
    "# Arrange the reconstructions in a grid (up to 10 per row) and display it\n",
    "\n",
    "grid = torchvision.utils.make_grid(\n",
    "    tensor=X_rec.reshape(-1,3,64,64),\n",
    "    nrow=10,\n",
    ")\n",
    "show(grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the first image of the batch, keeping a leading batch axis of size 1.\n",
    "# The two values returned by encode are unpacked into mu and logvar.\n",
    "\n",
    "mu, logvar = model.model.encode(X[:1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Traverse latent in [-3,3] for dimension d\n",
    "\n",
    "d_list = [1, 2, 11]\n",
    "n_traversals = 10\n",
    "# Detach so the traversal batch is not attached to the encoder's autograd graph\n",
    "z = mu.detach()\n",
    "\n",
    "# Replace d-th dimension of z with traversal of [-3,3] for each d in d_list.\n",
    "# Rows are grouped by dimension: row i*n_traversals+j holds z with entry\n",
    "# d_list[i] overwritten by the j-th of n_traversals evenly spaced values.\n",
    "\n",
    "z_traversals = torch.zeros(n_traversals*len(d_list), latent_dim).to(device)\n",
    "for i, d in enumerate(d_list):\n",
    "    for j, val in enumerate(torch.linspace(-3, 3, n_traversals)):\n",
    "        z_traversals[i*n_traversals+j] = z\n",
    "        z_traversals[i*n_traversals+j][d] = val\n",
    "\n",
    "# Decode z_traversals (inference only, so no autograd graph is recorded)\n",
    "\n",
    "with torch.no_grad():\n",
    "    X_traversals = model.model.decode(z_traversals).cpu()\n",
    "\n",
    "# Visualize X_traversals on torch grid.\n",
    "# nrow follows n_traversals so that each grid row shows exactly one latent\n",
    "# dimension, even if n_traversals is changed above.\n",
    "\n",
    "grid = torchvision.utils.make_grid(X_traversals.reshape(-1,3,64,64), nrow=n_traversals)\n",
    "show(grid)\n",
    "\n",
    "# Save figure (show() leaves its figure as the current one, so savefig writes it)\n",
    "\n",
    "plt.savefig('../results/beta_vae_traversals.png', bbox_inches='tight', dpi=300)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ldm4",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
