{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Example: Conditional Variational Autoencoder in Flax\n\nThis example trains a *Conditional Variational Autoencoder* (CVAE) [1] on the MNIST data\nusing Flax' neural network API. The implementation can be found here:\nhttps://github.com/pyro-ppl/numpyro/tree/master/examples/cvae-flax\n\nThe model is a port of Pyro's excellent CVAE example which describes the model as well as the data in detail:\nhttps://pyro.ai/examples/cvae.html\n\nThe model first trains a baseline to predict an entire MNIST image from a single quadrant of it\n(i.e., input is one quadrant of an image, output is the entire image (not the other three quadrants)).\nThen, in a second model, the generation/prior/recognition nets of the CVAE are trained while keeping the model\nparameters of the baseline fixed/frozen. We use Optax' `multi_transform` to apply different gradient transformations\nto the trainable parameters and the frozen parameters.\n\n\n<img src=\"file://../_static/img/examples/cvae.png\" align=\"center\">\n\n**References:**\n\n    1. Kihyuk Sohn, Xinchen Yan, Honglak Lee (2015), \"Learning Structured Output Representation using Deep\n       Conditional Generative Models\n       (https://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generative-models)\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}