{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bd2af8b-8528-42ba-a118-63711bbab8b8",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../src')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2a19f8f",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import distrax\n",
    "import equinox as eqx\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import optax\n",
    "import seaborn\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "import flowee\n",
    "import nn\n",
    "from distribution_embedding import FlowEmbedding, train\n",
    "\n",
    "np.random.seed(0)\n",
    "key = jax.random.PRNGKey(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fc72d0d",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "## Data Generation\n",
    "\n",
    "We create a family of distributions which correspond to parameterized donuts in $\\mathbb{R}^2$.  Each \"donut\" distribution is parameterized by a randomly sampled mean, radius, and width."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbabbe2c-4b6f-42eb-bd8a-4836b76c9b57",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def generate_data(num_dists, num_samples, mean=None, radius=None, width=None):\n",
    "    if mean is None:\n",
    "        mean = np.random.normal(0.0, 1.0, size=(num_dists, 1, 2))\n",
    "    if radius is None:\n",
    "        radius = np.random.uniform(0.5, 2.0, size=(num_dists, 1, 1))\n",
    "    if width is None:\n",
    "        width = np.random.uniform(0.01, 0.75, size=(num_dists, 1, 1))\n",
    "\n",
    "    a = np.random.normal(0.0, 1.0, size=(num_dists, num_samples, 2))\n",
    "    b = np.random.uniform(-0.5, 0.5, size=(num_dists, num_samples, 1))\n",
    "\n",
    "    return mean + radius * (a / np.linalg.norm(a, keepdims=True, axis=-1)) * (1.0 + width * b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f20ab3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure()\n",
    "\n",
    "# Parameters used for the figures in the paper\n",
    "mean = np.array([[[-2.7, 0.5]], [[0.25, 1.2]], [[0.0, -0.5]]])\n",
    "radius = np.array([[[2.0]], [[1.4]], [[1.6]]])\n",
    "width = np.array([[[0.45]], [[0.15]], [[0.3]]])\n",
    "\n",
    "data = generate_data(3, 768, mean, radius, width)\n",
    "for i in range(3):\n",
    "    plt.scatter(data[i, :, 0], data[i, :, 1], s=20.0, marker='.')\n",
    "\n",
    "plt.xlabel(r'$x_1$')\n",
    "plt.ylabel(r'$x_2$')\n",
    "plt.grid()\n",
    "\n",
    "# Save these limits so that we can use them for the second figure\n",
    "xlims = plt.gca().get_xlim()\n",
    "ylims = plt.gca().get_ylim()\n",
    "\n",
    "plt.savefig('donuts-example.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74a80c4d-aa95-4205-8afe-260c46129dc1",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "## Embedding Model using Normalizing Flows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2937128-d916-40f3-96d4-1091a224a481",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "input_dim = 2\n",
    "output_dim = 2\n",
    "embedding_dim = 8\n",
    "hidden_size = 64\n",
    "num_layers = 3\n",
    "num_params = 5\n",
    "\n",
    "key, embedding_key = jax.random.split(key, 2)\n",
    "embedding_network = eqx.nn.MLP(input_dim, embedding_dim, hidden_size, num_layers, key=embedding_key)\n",
    "\n",
    "hidden_size = 32\n",
    "num_layers = 3\n",
    "num_coupling_layers = 8\n",
    "\n",
    "key, *flow_keys = jax.random.split(key, 1 + num_coupling_layers)\n",
    "flow_network = flowee.Sequential([\n",
    "    flowee.Coupling(\n",
    "        flowee.create_mask((input_dim,), (1,)),\n",
    "        nn.MLP(input_dim + embedding_dim, input_dim * num_params, hidden_size, num_layers, key=flow_keys[i]),\n",
    "        flowee.ParameterizedNLSq((input_dim,)),\n",
    "        dual=True\n",
    "    )\n",
    "    for i in range(num_coupling_layers)\n",
    "])\n",
    "flow_network.add_prior(distrax.Normal(0.0, 1.0), (input_dim,))\n",
    "\n",
    "model = FlowEmbedding(embedding_network, flow_network)\n",
    "\n",
    "print(\n",
    "    'Model Parameters:', sum(\n",
    "        l.size for l in jax.tree_util.tree_leaves(eqx.filter(model, eqx.is_inexact_array))\n",
    "    )\n",
    ")\n",
    "\n",
    "optimizer = optax.adam(1e-3)\n",
    "optimizer_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2f15a5c-694c-4d2e-a33d-85a0fbb8783a",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "705764ac-f4b8-43e1-8f83-1f550a536587",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "losses = []\n",
    "num_dists = 32\n",
    "num_samples = 128\n",
    "\n",
    "progress_bar = tqdm(range(30000))\n",
    "for i in progress_bar:\n",
    "    key, train_key = jax.random.split(key, 2)\n",
    "\n",
    "    data = generate_data(num_dists, num_samples)\n",
    "    model, optimizer_state, loss_value = train(\n",
    "        model, optimizer, optimizer_state, data, key=train_key\n",
    "    )\n",
    "\n",
    "    losses.append(loss_value)\n",
    "\n",
    "    if i % 100 == 0:\n",
    "        progress_bar.set_postfix({'loss': np.mean(losses)})\n",
    "\n",
    "# Plot a smoothed sequence of loss values\n",
    "plt.plot(np.convolve(losses, np.ones(16), 'valid') / 16)\n",
    "plt.grid()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82359482-7162-49fe-83d2-186ecd4d3864",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04775e20-9852-44ab-b403-aa2030c543b1",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def evaluate(model, num_dists, num_embedding_samples, num_generated_samples, key):\n",
    "    data = generate_data(num_dists, num_embedding_samples + num_generated_samples)\n",
    "\n",
    "    def generate_samples(data, key):\n",
    "        return model.generate(\n",
    "            model.embed(data[:, :num_embedding_samples]),\n",
    "            nsamples=num_generated_samples, key=key\n",
    "        )\n",
    "\n",
    "    def rmse(x, y, axis=-1):\n",
    "        return np.mean(np.sqrt(np.sum((x - y) ** 2, axis=axis)))\n",
    "\n",
    "    key, *generate_keys = jax.random.split(key, 1 + num_dists)\n",
    "    samp = jax.vmap(generate_samples)(data, jnp.array(generate_keys))\n",
    "\n",
    "    data_mean = np.mean(data, keepdims=True, axis=1)\n",
    "    samp_mean = np.mean(samp, keepdims=True, axis=1)\n",
    "\n",
    "    diff_mean = rmse(data_mean, samp_mean)\n",
    "    print(f'Mean Error: {diff_mean:.5f}')\n",
    "\n",
    "    data_radius = np.linalg.norm(data - data_mean, axis=2)\n",
    "    samp_radius = np.linalg.norm(samp - samp_mean, axis=2)\n",
    "\n",
    "    data_radius_mean = np.mean(data_radius, keepdims=True, axis=1)\n",
    "    samp_radius_mean = np.mean(samp_radius, keepdims=True, axis=1)\n",
    "\n",
    "    diff_radius = rmse(data_radius_mean, samp_radius_mean)\n",
    "    print(f'Radius Error: {diff_radius:.5f}')\n",
    "\n",
    "    data_radius_diff = data_radius - data_radius_mean\n",
    "    samp_radius_diff = samp_radius - samp_radius_mean\n",
    "\n",
    "    data_radius_diff_std = np.std(data_radius_diff, keepdims=True, axis=1)\n",
    "    samp_radius_diff_std = np.std(samp_radius_diff, keepdims=True, axis=1)\n",
    "\n",
    "    diff_width = rmse(data_radius_diff_std, samp_radius_diff_std)\n",
    "    print(f'Width Error: {diff_width:.5f}')\n",
    "\n",
    "    diff_width = rmse(\n",
    "        data_radius_diff_std / data_radius_mean,\n",
    "        samp_radius_diff_std / samp_radius_mean\n",
    "    )\n",
    "    print(f'Width / Radius Error: {diff_width:.5f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42e7516f-9069-48af-93b1-334a28a21d19",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "num_dists = 1024\n",
    "num_embedding_samples = 128\n",
    "num_generated_samples = 1024\n",
    "\n",
    "key, evaluate_key = jax.random.split(key, 2)\n",
    "evaluate(model, num_dists, num_embedding_samples, num_generated_samples, evaluate_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6689e19d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_gradient_cmap(color):\n",
    "    if isinstance(color, str):\n",
    "        color = mpl.colors.to_rgb(color)\n",
    "\n",
    "    cmap_name = f'gradient_from_{color}'\n",
    "    colors = [(1.0, 1.0, 1.0, 0.0), (*color, 1.0)]\n",
    "\n",
    "    cmap = mpl.colors.LinearSegmentedColormap.from_list(cmap_name, colors)\n",
    "\n",
    "    return cmap\n",
    "\n",
    "true_samples = generate_data(3, 256, mean, radius, width)\n",
    "\n",
    "grid_size = 512\n",
    "x = np.linspace(*xlims, grid_size)\n",
    "y = np.linspace(*ylims, grid_size)\n",
    "x, y = np.meshgrid(x, y)\n",
    "grid = np.stack([x.ravel(), y.ravel()], axis=-1)\n",
    "\n",
    "color_cycler = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "for samples, color in zip(true_samples, color_cycler):\n",
    "    embedding = model.embed(samples)\n",
    "    log_probs = model.log_prob(grid, embedding)\n",
    "\n",
    "    density = np.exp(log_probs)\n",
    "    density = np.where(density >= 0.02, density, 0)\n",
    "    density = np.reshape(density, [grid_size, grid_size])\n",
    "\n",
    "    cmap = create_gradient_cmap(color)\n",
    "\n",
    "    plt.contourf(x, y, density, levels=10, cmap=cmap, norm='logit')\n",
    "    plt.scatter(samples[:, 0], samples[:, 1], s=20.0, color=color, marker='.')\n",
    "\n",
    "plt.xlabel(r'$x_1$')\n",
    "plt.ylabel(r'$x_2$')\n",
    "plt.grid()\n",
    "\n",
    "plt.savefig('donuts-densities.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f9f27d1-fc68-4b35-9f6e-e7d0a1ea606f",
   "metadata": {},
   "source": [
    "## Embedding Arithmetic\n",
    "\n",
    "Embeddings don't just allow for reconstruction; often they can also be used for interpolation or arithmetic.  Is that true for distribution embeddings?\n",
    "\n",
    "If $z_1$ and $z_2$ are the embeddings of two distributions, what distribution might the embedding $\\frac{1}{2}z_1 + \\frac{1}{2}z_2$ correspond to?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6fb2cdf-c685-48f5-a25e-6b5fe94bd4a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "start = generate_data(\n",
    "    2, 1024,\n",
    "    mean=jnp.array([[0, -1], [0, 1]]).reshape(2, 1, 2),\n",
    "    radius=jnp.array([0.7, 0.7]).reshape(2, 1, 1),\n",
    "    width=jnp.array([0.25, 0.25]).reshape(2, 1, 1)\n",
    ")\n",
    "\n",
    "num_steps = 6\n",
    "zs = jax.vmap(model.embed)(start)\n",
    "zs = jnp.concatenate(\n",
    "    (zs[0:1], jnp.concatenate([(num_steps - i) / (num_steps + 1) * zs[0:1] + (i + 1) / (num_steps + 1) * zs[1:2] for i in range(num_steps)]), zs[1:2])\n",
    ")\n",
    "\n",
    "key, *generate_keys = jax.random.split(key, 1 + zs.shape[0])\n",
    "xs = jax.vmap(model.generate, in_axes=(0, None))(zs, 1024, key=jnp.array(generate_keys))\n",
    "\n",
    "fig = plt.figure()\n",
    "fig.set_figwidth(16)\n",
    "\n",
    "for i in range(2 + num_steps):\n",
    "    plt.subplot(1, 2 + num_steps, i + 1)\n",
    "\n",
    "    seaborn.kdeplot({'x': xs[i, :, 0], 'y': xs[i, :, 1]}, x='x', y='y', fill=True, alpha=0.3)\n",
    "\n",
    "    plt.gca().set_ylim(-2.5, 2.5)\n",
    "    plt.gca().set_xlim(-1.5, 1.5)\n",
    "    plt.gca().set_xlabel(None)\n",
    "    plt.gca().set_ylabel(None)\n",
    "    plt.grid()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3e0dd29-98c4-49e9-ab22-c4887463dcb6",
   "metadata": {},
   "source": [
    "What about the directions in embedding space?  Consider $z_2 - z_1$.  What if we apply that difference to some embedding $z_3$?  Does it apply the same change?\n",
    "\n",
    "Here we start with two distributions whose means differ only along the $x_2$ (vertical) axis.  If we add the difference of their embeddings to the embedding of a third distribution, does its mean shift along $x_2$ by the same amount?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5637cc1-4d3a-4c1f-98f6-a4f1c5b88157",
   "metadata": {},
   "outputs": [],
   "source": [
    "start = generate_data(\n",
    "    3, 1024, \n",
    "    mean=jnp.array([[-1, -1], [-1, 1], [1, -1]]).reshape(-1, 1, 2),\n",
    "    radius=jnp.array([0.7, 0.7, 0.7]).reshape(-1, 1, 1),\n",
    "    width=jnp.array([0.25, 0.25, 0.25]).reshape(-1, 1, 1)\n",
    ")\n",
    "\n",
    "zs = jax.vmap(model.embed)(start)\n",
    "zs = jnp.concatenate((zs, jnp.array([zs[1] - zs[0] + zs[2]])))\n",
    "\n",
    "key, *generate_keys = jax.random.split(key, 1 + zs.shape[0])\n",
    "xs = jax.vmap(model.generate, in_axes=(0, None))(zs, 1024, key=jnp.array(generate_keys))\n",
    "\n",
    "for i in range(4):\n",
    "    seaborn.kdeplot({'x': xs[i, :, 0], 'y': xs[i, :, 1]}, x='x', y='y', fill=True, alpha=0.3)\n",
    "\n",
    "plt.grid()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
