{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 535,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = str(4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 536,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "h_dim 25088\n",
      "Loading checkpoint from vae_logs/cifar/conv_optimal_constant/checkpoint_10.pt\n"
     ]
    }
   ],
   "source": [
    "from __future__ import print_function\n",
    "import torch\n",
    "import numpy as np\n",
    "from torch import nn, optim\n",
    "from torch.nn import functional as F\n",
    "from torchvision.utils import save_image\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import torch.distributions\n",
    "\n",
    "from vae_models.vae_builder import VAEBuilder\n",
    "\n",
    "\n",
    "# MNIST MI expreiments\n",
    "# python3 -m pdb -c continue train_latents_vae.py --log_dir MNIST/nolb_beta0001_MI --learn_beta=0 --sigma=0.0223\n",
    "\n",
    "# MNIST\n",
    "arguments = ['--log_dir', 'MNIST/nolb_beta0001_MI', '--learn_beta', str(0), '--sigma', str(0.0223)]\n",
    "\n",
    "# CIFAR\n",
    "arguments = ['--dataset', 'CIFAR',\n",
    "             '--log_dir', 'cifar/nolb_beta10',\n",
    "             '--learn_beta', '0',\n",
    "             '--sigma', '0.0223',\n",
    "            '--n_filters', '128']\n",
    "\n",
    "arguments = ['--dataset', 'CIFAR',\n",
    "             '--log_dir', 'cifar/conv_optimal_constant',\n",
    "             '--sigma_mode', 'optimal_constant',\n",
    "            '--n_filters', '128']\n",
    "\n",
    "\n",
    "builder = VAEBuilder()\n",
    "args, device = builder.get_arguments(arguments)\n",
    "train_loader, test_loader = builder.get_dataset()\n",
    "summary_writer = builder.get_summary_writer(purge_step=None)\n",
    "model = builder.build_vae(device, args)\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "builder.load_initial_checkpoint(model)\n",
    "\n",
    "\n",
    "def logavgexp(tensor, dim):\n",
    "    \"\"\" Computes log-average-exponent, similar to logsumexp but with average instead of sum\n",
    "    log (avg) = log (sum / N) = log (sum) - log (N) \"\"\"\n",
    "\n",
    "    return torch.logsumexp(tensor, dim) - np.log(tensor.shape[dim])\n",
    "\n",
    "\n",
    "def log_ratio(sample, encoder, prior):\n",
    "    log_q_posterior = encoder.log_prob(sample[None])\n",
    "    log_q_marginal = logavgexp(log_q_posterior, 0)\n",
    "    \n",
    "    log_p = prior.log_prob(sample)\n",
    "    \n",
    "    ratio = (log_q_marginal - log_p).sum()\n",
    "    return ratio\n",
    "\n",
    "\n",
    "\"\"\" This script is used to store latents produced by a VAE over the whole dataset \"\"\"\n",
    "data = torch.load('vae_logs/{}/latents.pt'.format(args.log_dir))\n",
    "\n",
    "mu = data['mu']\n",
    "logvar = data['logvar']\n",
    "sigma = torch.exp(0.5 * logvar)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Posterior - prior KL estimates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 537,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exact divergence 60.20174026489258\n",
      "MC estimate 60.201148986816406\n",
      "encoder entropy -36.028079986572266\n"
     ]
    }
   ],
   "source": [
    "# Exact divergence (no sampling from encoder, sampling from dataset)\n",
    "KL_posterior_prior = model.kl_divergence_unit(mu, logvar) / mu.shape[0]\n",
    "print('Exact divergence', KL_posterior_prior.item())\n",
    "\n",
    "encoder = torch.distributions.Normal(mu, sigma)\n",
    "prior = torch.distributions.Normal(torch.zeros_like(mu)[[0]], torch.ones_like(sigma)[[0]])\n",
    "\n",
    "# MC estimate of posterior - prior KL\n",
    "samples = encoder.sample()\n",
    "KL_pp = (encoder.log_prob(samples) - prior.log_prob(samples)).sum(1).mean()\n",
    "print('MC estimate', KL_pp.item())\n",
    "\n",
    "# Encoder entropy\n",
    "print('encoder entropy', encoder.entropy().sum(1).mean().item())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MC estimate of marginal - prior KL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 538,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KL divergence between the marginal and the prior 1.7642931938171387\n"
     ]
    }
   ],
   "source": [
    "## MC estimate of marginal - prior KL\n",
    "latents_file = 'vae_logs/{}/kl_marginal_prior_list.pt'.format(args.log_dir)\n",
    "\n",
    "samples = encoder.sample()\n",
    "n_samples = samples.shape[0]\n",
    "n_samples = 1000\n",
    "use_caching = False\n",
    "\n",
    "if os.path.exists(latents_file) and use_caching:\n",
    "    KL_mp_list = torch.load(latents_file)\n",
    "else:\n",
    "    KL_mp_list = torch.stack([log_ratio(samples[i], encoder, prior) for i in range(n_samples)], 0)\n",
    "    if n_samples == samples.shape[0]:\n",
    "        torch.save(KL_mp_list, latents_file)\n",
    "\n",
    "KL_mp = KL_mp_list.mean()\n",
    "print('KL divergence between the marginal and the prior', KL_mp.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 539,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(1.7643, device='cuda:0')\n",
      "tensor(0.0123, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "print(KL_mp)\n",
    "\n",
    "# Standard error of the mean of the KL estimate\n",
    "print(KL_mp_list.std() / np.sqrt(samples.shape[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 540,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Entropy of the marginal 22.50543785095215\n",
      "Cross-Entropy between marginal and prior 24.269729614257812\n",
      "Entropy of the prior 28.37877066409345\n"
     ]
    }
   ],
   "source": [
    "# Entropy of the marginal and prior distributions\n",
    "log_p = prior.log_prob(samples[:n_samples]).sum(1)\n",
    "log_q_marginal = KL_mp_list + log_p\n",
    "\n",
    "marginal_entropy = -log_q_marginal.mean().item()\n",
    "print('Entropy of the marginal', -log_q_marginal.mean().item()) \n",
    "print('Cross-Entropy between marginal and prior', -log_p.mean().item())\n",
    "print('Entropy of the prior', (np.log(2 * np.pi * np.e * 1) / 2) * samples.shape[1])\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py3] *",
   "language": "python",
   "name": "conda-env-py3-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
