{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Variational Autoencoder\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import argparse\n",
        "import inspect\n",
        "import os\n",
        "import time\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from jax import jit, lax, random\n",
        "from jax.example_libraries import stax\n",
        "import jax.numpy as jnp\n",
        "from jax.random import PRNGKey\n",
        "\n",
        "import numpyro\n",
        "from numpyro import optim\n",
        "import numpyro.distributions as dist\n",
        "from numpyro.examples.datasets import MNIST, load_dataset\n",
        "from numpyro.infer import SVI, Trace_ELBO\n",
        "\n",
        "# Output directory for the original / reconstructed images saved by main().\n",
        "RESULTS_DIR = os.path.abspath(\n",
        "    os.path.join(os.path.dirname(inspect.getfile(lambda: None)), \".results\")\n",
        ")\n",
        "os.makedirs(RESULTS_DIR, exist_ok=True)\n",
        "\n",
        "\n",
        "def encoder(hidden_dim, z_dim):\n",
        "    \"\"\"Build the stax encoder: one softplus hidden layer feeding two heads.\n",
        "\n",
        "    Both heads are dense layers of width ``z_dim``; the second one is passed\n",
        "    through ``Exp`` so its output is positive. ``guide`` uses the pair as the\n",
        "    location and scale of the Normal over ``z``.\n",
        "    \"\"\"\n",
        "    return stax.serial(\n",
        "        stax.Dense(hidden_dim, W_init=stax.randn()),\n",
        "        stax.Softplus,\n",
        "        stax.FanOut(2),\n",
        "        stax.parallel(\n",
        "            stax.Dense(z_dim, W_init=stax.randn()),\n",
        "            stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp),\n",
        "        ),\n",
        "    )\n",
        "\n",
        "\n",
        "def decoder(hidden_dim, out_dim):\n",
        "    \"\"\"Build the stax decoder: softplus hidden layer, then a sigmoid output layer.\n",
        "\n",
        "    The sigmoid keeps every output in (0, 1); ``model`` passes the result to\n",
        "    ``dist.Bernoulli``.\n",
        "    \"\"\"\n",
        "    return stax.serial(\n",
        "        stax.Dense(hidden_dim, W_init=stax.randn()),\n",
        "        stax.Softplus,\n",
        "        stax.Dense(out_dim, W_init=stax.randn()),\n",
        "        stax.Sigmoid,\n",
        "    )\n",
        "\n",
        "\n",
        "def model(batch, hidden_dim=400, z_dim=100):\n",
        "    \"\"\"Generative model: standard-normal latent ``z`` decoded to Bernoulli pixels.\n",
        "\n",
        "    ``batch`` is flattened to ``(batch.shape[0], -1)`` and used as the observed\n",
        "    value of the ``obs`` site. Returns the value of that ``obs`` site.\n",
        "    \"\"\"\n",
        "    batch = jnp.reshape(batch, (batch.shape[0], -1))\n",
        "    batch_dim, out_dim = jnp.shape(batch)\n",
        "    decode = numpyro.module(\"decoder\", decoder(hidden_dim, out_dim), (batch_dim, z_dim))\n",
        "    with numpyro.plate(\"batch\", batch_dim):\n",
        "        z = numpyro.sample(\"z\", dist.Normal(0, 1).expand([z_dim]).to_event(1))\n",
        "        img_loc = decode(z)\n",
        "        return numpyro.sample(\"obs\", dist.Bernoulli(img_loc).to_event(1), obs=batch)\n",
        "\n",
        "\n",
        "def guide(batch, hidden_dim=400, z_dim=100):\n",
        "    \"\"\"Variational guide: Normal over ``z`` parameterised by the encoder output.\n",
        "\n",
        "    ``batch`` is flattened the same way as in ``model``. Returns the sampled ``z``.\n",
        "    \"\"\"\n",
        "    batch = jnp.reshape(batch, (batch.shape[0], -1))\n",
        "    batch_dim, out_dim = jnp.shape(batch)\n",
        "    encode = numpyro.module(\"encoder\", encoder(hidden_dim, z_dim), (batch_dim, out_dim))\n",
        "    z_loc, z_std = encode(batch)\n",
        "    with numpyro.plate(\"batch\", batch_dim):\n",
        "        return numpyro.sample(\"z\", dist.Normal(z_loc, z_std).to_event(1))\n",
        "\n",
        "\n",
        "@jit\n",
        "def binarize(rng_key, batch):\n",
        "    \"\"\"Draw Bernoulli samples using ``batch`` as probabilities, cast to ``batch.dtype``.\"\"\"\n",
        "    return random.bernoulli(rng_key, batch).astype(batch.dtype)\n",
        "\n",
        "\n",
        "def main(args):\n",
        "    \"\"\"Train the VAE on binarized MNIST, printing the test loss after every epoch.\n",
        "\n",
        "    After each epoch one test image and its reconstruction are written to\n",
        "    ``RESULTS_DIR``.\n",
        "\n",
        "    Args:\n",
        "        args: namespace providing ``hidden_dim``, ``z_dim``, ``learning_rate``,\n",
        "            ``batch_size`` and ``num_epochs``.\n",
        "    \"\"\"\n",
        "    # Stand-alone copies of the networks, used by reconstruct_img() to apply the\n",
        "    # learned parameters outside of the model/guide.\n",
        "    encoder_nn = encoder(args.hidden_dim, args.z_dim)\n",
        "    decoder_nn = decoder(args.hidden_dim, 28 * 28)\n",
        "    adam = optim.Adam(args.learning_rate)\n",
        "    svi = SVI(\n",
        "        model, guide, adam, Trace_ELBO(), hidden_dim=args.hidden_dim, z_dim=args.z_dim\n",
        "    )\n",
        "    rng_key = PRNGKey(0)\n",
        "    train_init, train_fetch = load_dataset(\n",
        "        MNIST, batch_size=args.batch_size, split=\"train\"\n",
        "    )\n",
        "    test_init, test_fetch = load_dataset(\n",
        "        MNIST, batch_size=args.batch_size, split=\"test\"\n",
        "    )\n",
        "    num_train, train_idx = train_init()\n",
        "    rng_key, rng_key_binarize, rng_key_init = random.split(rng_key, 3)\n",
        "    # A single binarized batch is enough to initialise the SVI state.\n",
        "    sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])\n",
        "    svi_state = svi.init(rng_key_init, sample_batch)\n",
        "\n",
        "    @jit\n",
        "    def epoch_train(svi_state, rng_key, train_idx):\n",
        "        \"\"\"Run one SVI update per training batch; return (summed loss, new state).\"\"\"\n",
        "\n",
        "        def body_fn(i, val):\n",
        "            loss_sum, svi_state = val\n",
        "            # fold_in derives a distinct binarization key for every batch index.\n",
        "            rng_key_binarize = random.fold_in(rng_key, i)\n",
        "            batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])\n",
        "            svi_state, loss = svi.update(svi_state, batch)\n",
        "            loss_sum += loss\n",
        "            return loss_sum, svi_state\n",
        "\n",
        "        return lax.fori_loop(0, num_train, body_fn, (0.0, svi_state))\n",
        "\n",
        "    @jit\n",
        "    def eval_test(svi_state, rng_key, test_idx):\n",
        "        \"\"\"Return the per-example test loss averaged over all test batches.\n",
        "\n",
        "        ``num_test`` is read from the enclosing scope; it is assigned in the\n",
        "        epoch loop below before this function is first called.\n",
        "        \"\"\"\n",
        "\n",
        "        def body_fun(i, loss_sum):\n",
        "            rng_key_binarize = random.fold_in(rng_key, i)\n",
        "            batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])\n",
        "            # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?\n",
        "            loss = svi.evaluate(svi_state, batch) / len(batch)\n",
        "            loss_sum += loss\n",
        "            return loss_sum\n",
        "\n",
        "        loss = lax.fori_loop(0, num_test, body_fun, 0.0)\n",
        "        loss = loss / num_test\n",
        "        return loss\n",
        "\n",
        "    def reconstruct_img(epoch, rng_key):\n",
        "        \"\"\"Save the first test image and its reconstruction for ``epoch``.\n",
        "\n",
        "        ``svi_state`` and ``test_idx`` are read from the enclosing scope at call\n",
        "        time, so the current epoch's values are used.\n",
        "        \"\"\"\n",
        "        img = test_fetch(0, test_idx)[0][0]\n",
        "        plt.imsave(\n",
        "            os.path.join(RESULTS_DIR, \"original_epoch={}.png\".format(epoch)),\n",
        "            img,\n",
        "            cmap=\"gray\",\n",
        "        )\n",
        "        rng_key_binarize, rng_key_sample = random.split(rng_key)\n",
        "        test_sample = binarize(rng_key_binarize, img)\n",
        "        params = svi.get_params(svi_state)\n",
        "        # The encoder's second head ends in Exp and is used as a standard\n",
        "        # deviation (as in guide), so it is named z_std rather than a variance.\n",
        "        z_loc, z_std = encoder_nn[1](\n",
        "            params[\"encoder$params\"], test_sample.reshape([1, -1])\n",
        "        )\n",
        "        z = dist.Normal(z_loc, z_std).sample(rng_key_sample)\n",
        "        img_loc = decoder_nn[1](params[\"decoder$params\"], z).reshape([28, 28])\n",
        "        plt.imsave(\n",
        "            os.path.join(RESULTS_DIR, \"recons_epoch={}.png\".format(epoch)),\n",
        "            img_loc,\n",
        "            cmap=\"gray\",\n",
        "        )\n",
        "\n",
        "    for i in range(args.num_epochs):\n",
        "        # One split per epoch supplies every key used below; a second split here\n",
        "        # would only overwrite the test/reconstruct keys before they are used.\n",
        "        rng_key, rng_key_train, rng_key_test, rng_key_reconstruct = random.split(\n",
        "            rng_key, 4\n",
        "        )\n",
        "        t_start = time.time()\n",
        "        num_train, train_idx = train_init()\n",
        "        _, svi_state = epoch_train(svi_state, rng_key_train, train_idx)\n",
        "        num_test, test_idx = test_init()\n",
        "        test_loss = eval_test(svi_state, rng_key_test, test_idx)\n",
        "        reconstruct_img(i, rng_key_reconstruct)\n",
        "        print(\n",
        "            \"Epoch {}: loss = {} ({:.2f} s.)\".format(\n",
        "                i, test_loss, time.time() - t_start\n",
        "            )\n",
        "        )\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    assert numpyro.__version__.startswith(\"0.13.2\")\n",
        "    parser = argparse.ArgumentParser(description=\"parse args\")\n",
        "    parser.add_argument(\n",
        "        \"-n\", \"--num-epochs\", default=15, type=int, help=\"number of training epochs\"\n",
        "    )\n",
        "    parser.add_argument(\n",
        "        \"-lr\", \"--learning-rate\", default=1.0e-3, type=float, help=\"learning rate\"\n",
        "    )\n",
        "    parser.add_argument(\"-batch-size\", default=128, type=int, help=\"batch size\")\n",
        "    parser.add_argument(\"-z-dim\", default=50, type=int, help=\"size of latent\")\n",
        "    parser.add_argument(\n",
        "        \"-hidden-dim\",\n",
        "        default=400,\n",
        "        type=int,\n",
        "        help=\"size of hidden layer in encoder/decoder networks\",\n",
        "    )\n",
        "    # Inside a notebook kernel sys.argv holds the kernel launcher's own\n",
        "    # arguments (-f <connection file>), which parse_args() would reject and\n",
        "    # exit on. parse_known_args() ignores them and still honours the options\n",
        "    # defined above when this code is run as a script.\n",
        "    args, _unknown = parser.parse_known_args()\n",
        "    main(args)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}