{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fb2b2856",
   "metadata": {},
   "source": [
    "# Continuous Normalizing Flow tutorial: training ODE generative models using maximum likelihood"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e98247e-44ac-4ab9-ad1e-ff058dbb6215",
   "metadata": {},
   "source": [
    "This implements a [continuous normalizing flow (CNF)](https://arxiv.org/abs/1806.07366) trained using maximum likelihood.\n",
    "\n",
    "To compute the likelihood of a sample $x_1$, we use the instantaneous change of variables formula integrated over time; that is, we have\n",
    "\n",
    "$$\n",
    "\\begin{pmatrix}\n",
    "\\partial x_t / \\partial t \\\\\n",
    "\\partial \\log p(x_t) / \\partial t\n",
    "\\end{pmatrix} = \n",
    " \\begin{pmatrix}\n",
    "f(t, x_t)\\\\\n",
    "-\\text{tr}(\\partial f / \\partial x_t)\n",
    "\\end{pmatrix} \n",
    "$$\n",
    "\n",
    "which is implemented as a $d+1$ dimensional system. There are two common ways to calculate $\\partial \\log p(x_t) / \\partial t$.\n",
    "* Exact calculation of the trace of the Jacobian, which requires essentially $d$ calls of $f$.\n",
    "* The Hutchinson trace estimator, with either a normal or a Rademacher distribution, which uses\n",
    "  $$\n",
    "    \\text{tr}(\\partial f / \\partial x_t) = \\mathbb{E}_{\\epsilon} \\left [ \\epsilon^T [\\partial f / \\partial x_t] \\epsilon \\right ]\n",
    "  $$\n",
    "  and can be used with a single call to $f$. $\\epsilon$ must be distributed such that $\\mathbb{E}(\\epsilon) = 0$ and $\\text{Cov}(\\epsilon) = I$. Most often Gaussian or Rademacher distributions are used, and are both implemented here.\n",
    "\n",
    "As compared to flow matching methods, this requires a calculation of the trace of the Jacobian and backpropagation through time, so it is significantly slower and more numerically unstable to train.\n",
    "\n",
    "Note: Requires a version of torch with `vmap` and `torch.func.jacrev`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2035a615",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import torch\n",
    "from torch.distributions import MultivariateNormal\n",
    "from torchdyn.core import NeuralODE\n",
    "\n",
    "from torchcfm.models import MLP\n",
    "from torchcfm.utils import plot_trajectories, sample_moons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b40e404c-ef9a-4242-8c47-ca8e7a197271",
   "metadata": {},
   "outputs": [],
   "source": [
    "class torch_wrapper(torch.nn.Module):\n",
    "    \"\"\"Adapter that exposes a time-conditioned model in a torchdyn compatible format.\n",
    "\n",
    "    The wrapped model receives ``x`` with the time value appended as one extra\n",
    "    trailing column.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "\n",
    "    def forward(self, t, x, *args, **kwargs):\n",
    "        \"\"\"Evaluate the wrapped model at time ``t`` for every row of ``x``.\n",
    "\n",
    "        Any additional positional or keyword arguments are accepted and ignored.\n",
    "        \"\"\"\n",
    "        n_rows = x.shape[0]\n",
    "        # Tile the time value once per row and reshape it into a column.\n",
    "        t_column = t.repeat(n_rows)[:, None]\n",
    "        model_input = torch.cat([x, t_column], 1)\n",
    "        return self.model(model_input)\n",
    "\n",
    "\n",
    "def exact_div_fn(u):\n",
    "    \"\"\"Build an exact divergence function for a vector field u:R^D -> R^D.\n",
    "\n",
    "    The returned callable takes a point ``x`` (extra positional arguments are\n",
    "    ignored) and returns the trace of the Jacobian of ``u`` at ``x``, with the\n",
    "    Jacobian produced by ``torch.func.jacrev``.\n",
    "    \"\"\"\n",
    "    jacobian_of_u = torch.func.jacrev(u)\n",
    "\n",
    "    def divergence(x, *args):\n",
    "        return torch.trace(jacobian_of_u(x))\n",
    "\n",
    "    return divergence\n",
    "\n",
    "\n",
    "def div_fn_hutch_trace(u):\n",
    "    \"\"\"Build a Hutchinson-style divergence estimator for a vector field ``u``.\n",
    "\n",
    "    The returned callable takes a point ``x`` and a probe vector ``eps`` and\n",
    "    returns ``eps^T J eps``, where ``J`` is the Jacobian of ``u`` at ``x``. The\n",
    "    product is obtained from a vector-Jacobian product, so ``J`` itself is\n",
    "    never materialised.\n",
    "    \"\"\"\n",
    "\n",
    "    def div_fn(x, eps):\n",
    "        _, pullback = torch.func.vjp(u, x)\n",
    "        eps_times_jac = pullback(eps)[0]\n",
    "        return torch.sum(eps_times_jac * eps)\n",
    "\n",
    "    return div_fn\n",
    "\n",
    "\n",
    "class cnf_wrapper(torch.nn.Module):\n",
    "    \"\"\"Wraps model to a torchdyn compatible CNF format.\n",
    "\n",
    "    Appends an additional dimension representing the change in likelihood\n",
    "    over time: the state passed to ``forward`` is the sample followed by one\n",
    "    extra trailing entry, and the matching output entry is the divergence of\n",
    "    the vector field.\n",
    "\n",
    "    Args:\n",
    "        model: callable evaluated on ``[x, t]`` concatenated along the last\n",
    "            dimension. It is called both on a whole batch and, inside\n",
    "            ``torch.vmap``, on a single unbatched sample.\n",
    "        likelihood_estimator: ``\"exact\"`` (trace of the full Jacobian),\n",
    "            ``\"hutch_gaussian\"`` or ``\"hutch_rademacher\"`` (Hutchinson estimator\n",
    "            with Gaussian / Rademacher probe vectors).\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if ``likelihood_estimator`` is not one of the\n",
    "            names listed above.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, likelihood_estimator=\"exact\"):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        self.div_fn, self.eps_fn = self.get_div_and_eps(likelihood_estimator)\n",
    "\n",
    "    def get_div_and_eps(self, likelihood_estimator):\n",
    "        \"\"\"Return ``(divergence factory, probe sampler)`` for the named estimator.\n",
    "\n",
    "        The probe sampler is ``None`` for the exact estimator, which needs no\n",
    "        random probe vector.\n",
    "        \"\"\"\n",
    "        if likelihood_estimator == \"exact\":\n",
    "            return exact_div_fn, None\n",
    "        if likelihood_estimator == \"hutch_gaussian\":\n",
    "            return div_fn_hutch_trace, torch.randn_like\n",
    "        if likelihood_estimator == \"hutch_rademacher\":\n",
    "\n",
    "            def eps_fn(x):\n",
    "                # randint_like already returns values in {0, 1} with the dtype\n",
    "                # and device of x, so no cast is applied: the probe keeps the\n",
    "                # precision of the state instead of being forced to float32.\n",
    "                return torch.randint_like(x, low=0, high=2) * 2 - 1.0\n",
    "\n",
    "            return div_fn_hutch_trace, eps_fn\n",
    "        raise NotImplementedError(\n",
    "            f\"likelihood estimator {likelihood_estimator} is not implemented\"\n",
    "        )\n",
    "\n",
    "    def forward(self, t, x, *args, **kwargs):\n",
    "        \"\"\"Return the time derivative of the augmented state.\n",
    "\n",
    "        Args:\n",
    "            t: time value; squeezed to a scalar tensor before use.\n",
    "            x: batch of augmented states (rows are samples). The last entry of\n",
    "                each row is dropped because the dynamics do not depend on it.\n",
    "\n",
    "        Returns:\n",
    "            The model output for every row with the per-sample divergence\n",
    "            appended as the last column.\n",
    "        \"\"\"\n",
    "        t = t.squeeze()\n",
    "        x = x[..., :-1]\n",
    "\n",
    "        def vecfield(y):\n",
    "            # Single-sample view of the model, as required by jacrev / vjp under vmap.\n",
    "            return self.model(torch.cat([y, t[None]]))\n",
    "\n",
    "        per_sample_div = self.div_fn(vecfield)\n",
    "        if self.eps_fn is None:\n",
    "            div = torch.vmap(per_sample_div)(x)\n",
    "        else:\n",
    "            # A fresh probe vector is drawn on every call.\n",
    "            div = torch.vmap(per_sample_div)(x, self.eps_fn(x))\n",
    "        dx = self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))\n",
    "        return torch.cat([dx, div[:, None]], dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faf18883",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Maximum-likelihood training of the CNF on data drawn with `sample_moons`.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "dim = 2\n",
    "batch_size = 256\n",
    "n_iterations = 1000  # optimizer steps\n",
    "log_every = 200  # report the loss and plot samples every `log_every` steps\n",
    "n_plot_samples = 1024  # Gaussian samples pushed through the flow for each plot\n",
    "model = MLP(dim=dim, time_varying=True).to(device)\n",
    "prior = MultivariateNormal(torch.zeros(dim, device=device), torch.eye(dim, device=device))\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "steps = 100  # solver steps between t=0 and t=1\n",
    "# Both ODEs share `model`: `cnf` also tracks the divergence (used for the\n",
    "# likelihood), `node` only moves samples (used for plotting).\n",
    "cnf = NeuralODE(\n",
    "    cnf_wrapper(model, likelihood_estimator=\"exact\"), solver=\"euler\", sensitivity=\"adjoint\"\n",
    ")\n",
    "node = NeuralODE(torch_wrapper(model), solver=\"euler\", sensitivity=\"adjoint\")\n",
    "\n",
    "# Time grids are loop-invariant: data -> prior for training, prior -> data for plotting.\n",
    "backward_t_span = torch.linspace(1, 0, steps + 1, device=device)\n",
    "forward_t_span = torch.linspace(0, 1, steps + 1, device=device)\n",
    "\n",
    "start = time.perf_counter()\n",
    "for k in range(n_iterations):\n",
    "    optimizer.zero_grad()\n",
    "    x1 = sample_moons(batch_size).to(device)\n",
    "    # Augment the data with a zero-initialised accumulator column.\n",
    "    x1_with_ll = torch.cat([x1, torch.zeros(batch_size, 1, device=device)], dim=-1)\n",
    "    # Integrate backwards in time and keep the final state of the trajectory.\n",
    "    x0_with_ll = cnf.trajectory(x1_with_ll, t_span=backward_t_span)[-1]\n",
    "    # log p(x1) = log prior(x0) + accumulated divergence term.\n",
    "    logprob = prior.log_prob(x0_with_ll[..., :-1]) + x0_with_ll[..., -1]\n",
    "    loss = -torch.mean(logprob)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if (k + 1) % log_every == 0:\n",
    "        end = time.perf_counter()\n",
    "        print(f\"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}\")\n",
    "\n",
    "        with torch.no_grad():\n",
    "            traj = node.trajectory(\n",
    "                torch.randn(n_plot_samples, dim, device=device),\n",
    "                t_span=forward_t_span,\n",
    "            )\n",
    "            plot_trajectories(traj.cpu().numpy())\n",
    "        # Restart the clock only after plotting so the next report measures\n",
    "        # training time alone.\n",
    "        start = time.perf_counter()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}