{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "547dfe85",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Conditional Flow Matching\n",
    "\n",
    "This notebook is a self-contained example of conditional flow matching. We implement a number of different simulation-free methods for learning flow models. They differ based on the interpolant used and the loss function used to train them.\n",
    "\n",
    "In this notebook we implement 5 models that can map from a source distribution $q_0$ to a target distribution $q_1$:\n",
    "* Conditional Flow Matching (CFM)\n",
    "    * This is equivalent to the basic (non-rectified) formulation of \"Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow\" [(Liu et al. 2023)](https://openreview.net/forum?id=XVjTT1nw5z)\n",
    "    * Is siexampler to \"Stochastic Interpolants\" [(Anonymous et al. 2023)](https://openreview.net/forum?id=li7qeBbCR1t) with a non-variance preserving interpolant.\n",
    "    * Is siexampler to \"Flow Matching\" [(Anonymous et al. 2023)](https://openreview.net/forum?id=PqvMRDCJT9t) but conditions on both source and target.\n",
    "* Optimal Transport CFM (OT-CFM), which directly optimizes for dynamic optimal transport\n",
    "* Schrödinger Bridge CFM (SB-CFM), which optimizes for Schrödinger Bridge probability paths\n",
    "* \"Building Normalizing Flows with Stochastic Interpolants\" [(Anonymous et al. 2023)](https://openreview.net/forum?id=li7qeBbCR1t) this corresponds to \"VP-CFM\" in our README referring to its variance preserving properties.\n",
    "* \"Action Matching: Learning Stochastic Dynamics From Samples\" [(Neklyudov et al. 2022)](https://arxiv.org/abs/2210.06662)\n",
    "\n",
    "Note that this Flow Matching is different from the Generative Flow Network Flow Matching losses. Here we specifically regress against continuous flows, rather than matching inflows and outflows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb51734b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import ot as pot\n",
    "import torch\n",
    "import torchdyn\n",
    "from torchdyn.core import NeuralODE\n",
    "from torchdyn.datasets import generate_moons\n",
    "\n",
    "from torchcfm.conditional_flow_matching import *\n",
    "from torchcfm.models.models import *\n",
    "from torchcfm.utils import *\n",
    "\n",
    "savedir = \"models/8gaussian-moons\"\n",
    "os.makedirs(savedir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b241a60",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Conditional Flow Matching\n",
    "\n",
    "First we implement the basic conditional flow matching. As in the paper, we have\n",
    "$$\n",
    "\\begin{align}\n",
    "z &= (x_0, x_1) \\\\\n",
    "q(z) &= q(x_0)q(x_1) \\\\\n",
    "p_t(x | z) &= \\mathcal{N}(x | t * x_1 + (1 - t) * x_0, \\sigma^2) \\\\\n",
    "u_t(x | z) &= x_1 - x_0\n",
    "\\end{align}\n",
    "$$\n",
    "When $\\sigma = 0$ this is equivalent to zero-steps of rectified flow. We find that small $\\sigma$ helps to regularize the problem ymmv."
   ]
  },
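  {
   "cell_type": "markdown",
   "id": "3f9a1c07",
   "metadata": {},
   "source": [
    "The sampling step implied by the equations above can be sketched in a few lines of plain PyTorch. This is only an illustration under our own assumptions: `sample_cfm_pair` is a hypothetical helper (not part of `torchcfm`), and we assume $t \\sim \\mathcal{U}(0, 1)$ with one time per sample.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def sample_cfm_pair(x0, x1, sigma=0.1):\n",
    "    # Hypothetical helper for illustration, not part of torchcfm.\n",
    "    t = torch.rand(x0.shape[0], 1)  # assumed: t ~ U(0, 1), one time per sample\n",
    "    mu_t = t * x1 + (1 - t) * x0  # mean of p_t(x | z)\n",
    "    xt = mu_t + sigma * torch.randn_like(x0)  # x_t ~ N(mu_t, sigma^2 I)\n",
    "    ut = x1 - x0  # conditional target u_t(x | z)\n",
    "    return t.squeeze(-1), xt, ut\n",
    "\n",
    "\n",
    "x0 = torch.randn(4, 2)\n",
    "x1 = torch.randn(4, 2) + 3.0\n",
    "t, xt, ut = sample_cfm_pair(x0, x1, sigma=0.0)\n",
    "```\n",
    "\n",
    "With `sigma=0.0` the noise term vanishes and `xt` lies exactly on the straight line between `x0` and `x1`. The library sampler used in the next cell returns the same three quantities `t, xt, ut`."
   ]
  },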
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "176eb7fe",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "sigma = 0.1\n",
    "dim = 2\n",
    "batch_size = 256\n",
    "model = MLP(dim=dim, time_varying=True)\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "FM = ConditionalFlowMatcher(sigma=sigma)\n",
    "\n",
    "start = time.time()\n",
    "for k in range(20000):\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    x0 = sample_8gaussians(batch_size)\n",
    "    x1 = sample_moons(batch_size)\n",
    "\n",
    "    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)\n",
    "\n",
    "    vt = model(torch.cat([xt, t[:, None]], dim=-1))\n",
    "    loss = torch.mean((vt - ut) ** 2)\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if (k + 1) % 5000 == 0:\n",
    "        end = time.time()\n",
    "        print(f\"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}\")\n",
    "        start = end\n",
    "        node = NeuralODE(\n",
    "            torch_wrapper(model), solver=\"dopri5\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4\n",
    "        )\n",
    "        with torch.no_grad():\n",
    "            traj = node.trajectory(\n",
    "                sample_8gaussians(1024),\n",
    "                t_span=torch.linspace(0, 1, 100),\n",
    "            )\n",
    "            plot_trajectories(traj.cpu().numpy())\n",
    "torch.save(model, f\"{savedir}/cfm_v1.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e8057b9",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Optimal Transport Conditional Flow Matching\n",
    "\n",
    "Next we implement optimal transport conditional flow matching. As in the paper, here we have\n",
    "$$\n",
    "\\begin{align}\n",
    "z &= (x_0, x_1) \\\\\n",
    "q(z) &= \\pi(x_0, x_1) \\\\\n",
    "p_t(x | z) &= \\mathcal{N}(x | t * x_1 + (1 - t) * x_0, \\sigma^2) \\\\\n",
    "u_t(x | z) &= x_1 - x_0\n",
    "\\end{align}\n",
    "$$\n",
    "where $\\pi$ is the joint of an exact optimal transport matrix. We first sample random $x_0, x_1$, then resample according to the optimal transport matrix as computed with the python optimal transport package. We use the 2-Wasserstein distance with an $L^2$ ground distance for equivalence with dynamic optimal transport."
   ]
  },
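  {
   "cell_type": "markdown",
   "id": "b7e2c4a1",
   "metadata": {},
   "source": [
    "The resampling step described above can be sketched with NumPy and POT. This is a rough illustration, not the code path used by `ExactOptimalTransportConditionalFlowMatcher`: `resample_with_ot_plan` is a hypothetical helper, and we assume uniform weights on both minibatches and a squared Euclidean cost.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import ot\n",
    "\n",
    "\n",
    "def resample_with_ot_plan(x0, x1, rng):\n",
    "    # Hypothetical helper for illustration, not part of torchcfm.\n",
    "    n, m = len(x0), len(x1)\n",
    "    cost = ot.dist(x0, x1)  # squared Euclidean cost (POT default metric)\n",
    "    plan = ot.emd(ot.unif(n), ot.unif(m), cost)  # exact OT plan, uniform marginals\n",
    "    p = plan.flatten()\n",
    "    p = p / p.sum()  # guard against round-off before sampling\n",
    "    idx = rng.choice(n * m, size=n, p=p)  # draw pairs (i, j) with probability pi[i, j]\n",
    "    i, j = np.divmod(idx, m)\n",
    "    return x0[i], x1[j], plan\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "x0 = rng.normal(size=(8, 2))\n",
    "x1 = rng.normal(size=(8, 2)) + 3.0\n",
    "x0_ot, x1_ot, plan = resample_with_ot_plan(x0, x1, rng)\n",
    "```\n",
    "\n",
    "The rows and columns of `plan` each sum to the uniform weight `1 / 8`, so the resampled pairs keep the batch marginals in expectation while favouring pairs that are cheap to transport."
   ]
  },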
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aed74817",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "sigma = 0.1\n",
    "dim = 2\n",
    "batch_size = 256\n",
    "model = MLP(dim=dim, time_varying=True)\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)\n",
    "\n",
    "start = time.time()\n",
    "for k in range(20000):\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    x0 = sample_8gaussians(batch_size)\n",
    "    x1 = sample_moons(batch_size)\n",
    "\n",
    "    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)\n",
    "\n",
    "    vt = model(torch.cat([xt, t[:, None]], dim=-1))\n",
    "    loss = torch.mean((vt - ut) ** 2)\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if (k + 1) % 5000 == 0:\n",
    "        end = time.time()\n",
    "        print(f\"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}\")\n",
    "        start = end\n",
    "        node = NeuralODE(\n",
    "            torch_wrapper(model), solver=\"dopri5\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4\n",
    "        )\n",
    "        with torch.no_grad():\n",
    "            traj = node.trajectory(\n",
    "                sample_8gaussians(1024),\n",
    "                t_span=torch.linspace(0, 1, 100),\n",
    "            )\n",
    "            plot_trajectories(traj.cpu().numpy())\n",
    "torch.save(model, f\"{savedir}/otcfm_v1.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a7d8251",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Schrödinger Bridge Conditional Flow Matching\n",
    "\n",
    "Next we implement Schrödinger Bridge conditional flow matching. As in the paper, here we have\n",
    "$$\n",
    "\\begin{align}\n",
    "z &= (x_0, x_1) \\\\\n",
    "q(z) &= \\pi_{2 \\sigma^2} (x_0, x_1) \\\\\n",
    "p_t(x | z) &= \\mathcal{N}( x \\mid t x_1 + (1 - t) x_0, t(1-t)\\sigma^2)\\\\\n",
    "u_t(x | z) &= \\frac{1-2t}{2t(1-t)}(x - ( t x_1 + (1-t)x_0) ) + (x_1 - x_0)\n",
    "\\end{align}\n",
    "$$\n",
    "where $\\pi_{2 \\sigma^2}$ is the joint of a **Sinkhorn** optimal transport matrix with regularization $2 \\sigma^2$. As in OT-CFM We first sample random $x_0, x_1$, then resample according to the optimal transport matrix as computed with the python optimal transport package. We use the 2-Wasserstein distance with an $L^2$ ground distance for equivalence with the probability flow of a Schrödinger Bridge with reference measure $\\sigma W$.\n",
    "\n",
    "Note that the drift $v_\\theta(t,x)$ we learn is *not* equivalent to the drift of the stochastic system, instead it is the drift of the equivalent probability flow ODE, however they are related through $p_t(x)$.\n",
    "\n",
    "Also note that we use a larger $\\sigma$ here both for convergence of the Sinkhorn algorithm and to make the differences more visible."
   ]
  },
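  {
   "cell_type": "markdown",
   "id": "d94f0e36",
   "metadata": {},
   "source": [
    "The conditional flow above can be recovered with a short derivation. As a sketch, we assume the standard result for Gaussian probability paths $p_t(x | z) = \\mathcal{N}(x \\mid \\mu_t, \\sigma_t^2)$: the vector field $u_t(x | z) = \\frac{\\sigma_t'}{\\sigma_t}(x - \\mu_t) + \\mu_t'$ generates this path, where primes denote time derivatives. With $\\mu_t = t x_1 + (1 - t) x_0$ and $\\sigma_t = \\sigma \\sqrt{t(1-t)}$ we get\n",
    "$$\n",
    "\\begin{align}\n",
    "\\mu_t' &= x_1 - x_0 \\\\\n",
    "\\sigma_t' &= \\frac{\\sigma (1 - 2t)}{2 \\sqrt{t(1-t)}} \\\\\n",
    "\\frac{\\sigma_t'}{\\sigma_t} &= \\frac{1 - 2t}{2 t (1 - t)} \\\\\n",
    "u_t(x | z) &= \\frac{1-2t}{2t(1-t)}(x - ( t x_1 + (1-t)x_0) ) + (x_1 - x_0)\n",
    "\\end{align}\n",
    "$$\n",
    "When $\\sigma_t$ is constant in time, as in CFM and OT-CFM, the first term vanishes and we recover $u_t(x | z) = x_1 - x_0$. The coefficient $\\frac{1-2t}{2t(1-t)}$ diverges as $t \\to 0$ or $t \\to 1$, where $\\sigma_t \\to 0$."
   ]
  },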
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82c0eab7",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "sigma = 0.5\n",
    "dim = 2\n",
    "batch_size = 256\n",
    "model = MLP(dim=dim, time_varying=True)\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "# For best performance, use ot_method=\"exact\". To follow the theory, use ot_method=\"sinkhorn\"\n",
    "FM = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma, ot_method=\"exact\")\n",
    "\n",
    "start = time.time()\n",
    "for k in range(20000):\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    x0 = sample_8gaussians(batch_size)\n",
    "    x1 = sample_moons(batch_size)\n",
    "\n",
    "    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)\n",
    "\n",
    "    vt = model(torch.cat([xt, t[:, None]], dim=-1))\n",
    "    loss = torch.mean((vt - ut) ** 2)\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if (k + 1) % 5000 == 0:\n",
    "        end = time.time()\n",
    "        print(f\"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}\")\n",
    "        start = end\n",
    "        node = NeuralODE(\n",
    "            torch_wrapper(model), solver=\"dopri5\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4\n",
    "        )\n",
    "        with torch.no_grad():\n",
    "            traj = node.trajectory(\n",
    "                sample_8gaussians(1024),\n",
    "                t_span=torch.linspace(0, 1, 100),\n",
    "            )\n",
    "            plot_trajectories(traj.cpu().numpy())\n",
    "torch.save(model, f\"{savedir}/sbcfm_v1.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5890705",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Stochastic Interpolants (Anonymous et al. 2023)\n",
    "\n",
    "Next we try a variant suggested by Anonymous et al. 2023, Stochastic Interpolants. This is an interesting interpolant because it has some variance preservation properties. Note that the authors consider $\\sigma = 0$ (i.e. a Dirac around $\\mu_t$) but we keep the general form for consistency. We also refer to this $t$ schedule as \"Variance Preserving\" and call this VP-CFM. The authors also consider optimizing over more general interpolants $I_t = \\alpha(t) x_0 + \\beta(t) x_1$ with minor constraints. In our notation, we have\n",
    "$$\n",
    "\\begin{align}\n",
    "z &= (x_0, x_1) \\\\\n",
    "q(z) &= q(x_0)q(x_1) \\\\\n",
    "p_t(x | z) &= \\mathcal{N}(x | \\cos \\left (\\frac{\\pi t}{2}  \\right ) x_0 + \\sin \\left (\\frac{\\pi t}{2}  \\right ) x_1, \\sigma^2) \\\\\n",
    "u_t(x | z) &= \\frac{\\pi}{2} \\left (\\cos (\\frac{\\pi t}{2}) x_1 - \\sin(\\frac{\\pi t}{2}) x_0 \\right )\n",
    "\\end{align}\n",
    "$$"
   ]
  },
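  {
   "cell_type": "markdown",
   "id": "5c1a8f2b",
   "metadata": {},
   "source": [
    "Two short checks on the formulas above, written as a sketch. First, $u_t(x | z)$ is the time derivative of the conditional mean $\\mu_t = \\cos(\\frac{\\pi t}{2}) x_0 + \\sin(\\frac{\\pi t}{2}) x_1$. Second, the name \"variance preserving\" can be illustrated under an assumption that does not hold exactly for the datasets in this notebook: if $x_0$ and $x_1$ are independent, zero-mean, and both have identity covariance $I$, then $\\mu_t$ has identity covariance for every $t$, whereas the linear interpolant used by CFM shrinks the covariance in the middle of the path.\n",
    "$$\n",
    "\\begin{align}\n",
    "\\frac{d}{dt} \\mu_t &= \\frac{\\pi}{2} \\left( \\cos (\\frac{\\pi t}{2}) x_1 - \\sin(\\frac{\\pi t}{2}) x_0 \\right) \\\\\n",
    "\\mathrm{Cov}\\left[ \\cos (\\frac{\\pi t}{2}) x_0 + \\sin (\\frac{\\pi t}{2}) x_1 \\right] &= \\left( \\cos^2 (\\frac{\\pi t}{2}) + \\sin^2 (\\frac{\\pi t}{2}) \\right) I = I \\\\\n",
    "\\mathrm{Cov}\\left[ (1 - t) x_0 + t x_1 \\right] &= \\left( (1-t)^2 + t^2 \\right) I, \\quad \\text{which equals } \\tfrac{1}{2} I \\text{ at } t = \\tfrac{1}{2}\n",
    "\\end{align}\n",
    "$$"
   ]
  },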
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8df2cf5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "sigma = 0.1\n",
    "dim = 2\n",
    "batch_size = 256\n",
    "model = MLP(dim=dim, time_varying=True)\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "pi = math.pi\n",
    "FM = VariancePreservingConditionalFlowMatcher(sigma=sigma)\n",
    "\n",
    "start = time.time()\n",
    "for k in range(20000):\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    x0 = sample_8gaussians(batch_size)\n",
    "    x1 = sample_moons(batch_size)\n",
    "\n",
    "    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)\n",
    "\n",
    "    vt = model(torch.cat([xt, t[:, None]], dim=-1))\n",
    "    loss = torch.mean((vt - ut) ** 2)\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if (k + 1) % 5000 == 0:\n",
    "        end = time.time()\n",
    "        print(f\"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}\")\n",
    "        start = end\n",
    "        node = NeuralODE(\n",
    "            torch_wrapper(model), solver=\"dopri5\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4\n",
    "        )\n",
    "        with torch.no_grad():\n",
    "            traj = node.trajectory(\n",
    "                sample_8gaussians(1024),\n",
    "                t_span=torch.linspace(0, 1, 100),\n",
    "            )\n",
    "            plot_trajectories(traj.cpu().numpy())\n",
    "torch.save(model, f\"{savedir}/stochastic_interpolant_v1.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18ddf18b",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Action Matching (Neklyudov et al. 2022)\n",
    "\n",
    "Next we try a variant called action matching. Here we parametrize the velocity field $v_\\theta(t, x)$ as $\\nabla s_\\theta(t, x)$ where $s_\\theta(t, x): \\mathbb{R} \\times \\mathbb{R}^d \\to \\mathbb{R}$ is interpreted as the **action**. This is an interesting parameterization because of its link with optimal transport. Namely this velocity performs instantaneous optimal transport flow over $p_t(x)$. This is slightly different than optimal transport between marginals, but is also quite interesting. Action matching can be summarized in the following way:\n",
    "$$\n",
    "\\begin{align}\n",
    "z &= (x_0, x_1) \\\\\n",
    "q(z) &= q(x_0)q(x_1) \\\\\n",
    "p_t(x | z) &= \\mathcal{N}(x | (1-t) x_0 + t x_1, \\sigma^2)\\\\\n",
    "L_{AM}(\\theta) &= s_\\theta(0, x_0) - s_\\theta(1, x_1) + \\frac{1}{2} \\| \\nabla_x s_\\theta(t, x_t)\\|^2 + \\frac{\\partial}{\\partial t} s_\\theta(t, x_t)\n",
    "\\end{align}\n",
    "$$\n",
    "Note that the authors again consider $\\sigma = 0$ (i.e. a Dirac around $\\mu_t$) but we keep the general form for consistency assuming that $\\mathcal{N}(x | \\mu_t, 0)$ is a degenerate Dirac centered at $\\mu_t$. Our standard parameterization seems to be more difficult to fit with this loss (3-layer MLP with width 64 and SELU activations). It's unclear to me why this is the case, but as suggested in their repo using ReLU, Swish, Swish activations works much better."
   ]
  },
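  {
   "cell_type": "markdown",
   "id": "e03b6d7f",
   "metadata": {},
   "source": [
    "The gradient parameterization $v_\\theta(t, x) = \\nabla_x s_\\theta(t, x)$ can be sketched with autograd. The names `action_velocity` and `toy_action` are hypothetical and only serve this illustration; they are not the `GradModel` used in the training cells. We assume the network input is the concatenation `[x, t]`, as in the rest of this notebook.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def action_velocity(action, t, x):\n",
    "    # Hypothetical helper: returns v(t, x) = grad_x s(t, x) for a scalar action.\n",
    "    x = x.detach().requires_grad_(True)\n",
    "    s = action(torch.cat([x, t], dim=-1)).sum()  # summing keeps per-sample gradients separate\n",
    "    (dsdx,) = torch.autograd.grad(s, x)\n",
    "    return dsdx\n",
    "\n",
    "\n",
    "def toy_action(inp):\n",
    "    # Toy action s(t, x) = 0.5 * t * ||x||^2, whose gradient in x is t * x.\n",
    "    x, t = inp[:, :-1], inp[:, -1:]\n",
    "    return 0.5 * t * (x**2).sum(-1, keepdim=True)\n",
    "\n",
    "\n",
    "x = torch.randn(5, 2)\n",
    "t = torch.full((5, 1), 0.5)\n",
    "v = action_velocity(toy_action, t, x)\n",
    "```\n",
    "\n",
    "For the toy action the result can be checked by hand: at $t = 0.5$ the velocity is $0.5 x$. Because the velocity is itself a gradient, integrating it requires autograd to stay enabled."
   ]
  },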
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdb486cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchcfm.models.models import GradModel\n",
    "\n",
    "# %%time\n",
    "dim = 2\n",
    "batch_size = 256\n",
    "action = MLP(dim=dim, out_dim=1, time_varying=True, w=64)\n",
    "model = GradModel(action)\n",
    "optimizer = torch.optim.Adam(action.parameters())\n",
    "\n",
    "start = time.time()\n",
    "for k in range(20000):\n",
    "    optimizer.zero_grad()\n",
    "    t = torch.rand(batch_size, 1).requires_grad_(True)\n",
    "    x0 = sample_8gaussians(batch_size)\n",
    "    x1 = sample_moons(batch_size)\n",
    "    xt = (t * x1 + (1 - t) * x0).detach().requires_grad_(True)\n",
    "    st = torch.sum(action(torch.cat([xt, t], dim=-1)))\n",
    "    dsdx, dsdt = torch.autograd.grad(st, (xt, t), create_graph=True, retain_graph=True)\n",
    "    xt.requires_grad, t.requires_grad = False, False\n",
    "    a0 = action(torch.cat([x0, torch.zeros(batch_size, 1)], dim=-1))\n",
    "    a1 = action(torch.cat([x1, torch.ones(batch_size, 1)], dim=-1))\n",
    "    loss = a0 - a1 + 0.5 * (dsdx**2).sum(1, keepdims=True) + dsdt\n",
    "    loss = loss.mean()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if (k + 1) % 5000 == 0:\n",
    "        end = time.time()\n",
    "        print(f\"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}\")\n",
    "        start = end\n",
    "        node = NeuralODE(\n",
    "            torch_wrapper(model), solver=\"euler\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4\n",
    "        )\n",
    "        # with torch.no_grad():\n",
    "        traj = node.trajectory(\n",
    "            sample_8gaussians(1024),\n",
    "            t_span=torch.linspace(0, 1, 100),\n",
    "        ).detach()\n",
    "        plot_trajectories(traj.cpu().numpy())\n",
    "torch.save(model, f\"{savedir}/action_matching_v1.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69b3c3de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%time\n",
    "\n",
    "\n",
    "class MLP2(torch.nn.Module):\n",
    "    def __init__(self, dim, out_dim=None, w=64, time_varying=False):\n",
    "        super().__init__()\n",
    "        self.time_varying = time_varying\n",
    "        if out_dim is None:\n",
    "            out_dim = dim\n",
    "        self.net = torch.nn.Sequential(\n",
    "            torch.nn.Linear(dim + (1 if time_varying else 0), w),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(w, w),\n",
    "            torch.nn.SiLU(),\n",
    "            torch.nn.Linear(w, w),\n",
    "            torch.nn.SiLU(),\n",
    "            torch.nn.Linear(w, out_dim),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "\n",
    "dim = 2\n",
    "batch_size = 256\n",
    "action = MLP2(dim=dim, out_dim=1, time_varying=True, w=64)\n",
    "model = GradModel(action)\n",
    "optimizer = torch.optim.Adam(action.parameters())\n",
    "\n",
    "start = time.time()\n",
    "for k in range(20000):\n",
    "    optimizer.zero_grad()\n",
    "    t = torch.rand(batch_size, 1).requires_grad_(True)\n",
    "    x0 = sample_8gaussians(batch_size)\n",
    "    x1 = sample_moons(batch_size)\n",
    "    xt = (t * x1 + (1 - t) * x0).detach().requires_grad_(True)\n",
    "    st = torch.sum(action(torch.cat([xt, t], dim=-1)))\n",
    "    dsdx, dsdt = torch.autograd.grad(st, (xt, t), create_graph=True, retain_graph=True)\n",
    "    xt.requires_grad, t.requires_grad = False, False\n",
    "    a0 = action(torch.cat([x0, torch.zeros(batch_size, 1)], dim=-1))\n",
    "    a1 = action(torch.cat([x1, torch.ones(batch_size, 1)], dim=-1))\n",
    "    loss = a0 - a1 + 0.5 * (dsdx**2).sum(1, keepdims=True) + dsdt\n",
    "    loss = loss.mean()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if (k + 1) % 5000 == 0:\n",
    "        end = time.time()\n",
    "        print(f\"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}\")\n",
    "        start = end\n",
    "        node = NeuralODE(\n",
    "            torch_wrapper(model), solver=\"euler\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4\n",
    "        )\n",
    "        # with torch.no_grad():\n",
    "        traj = node.trajectory(\n",
    "            sample_8gaussians(1024),\n",
    "            t_span=torch.linspace(0, 1, 100),\n",
    "        ).detach()\n",
    "        plot_trajectories(traj.cpu().numpy())\n",
    "torch.save(model, f\"{savedir}/action_matching_swish_v1.pt\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torchcfm",
   "language": "python",
   "name": "torchcfm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}