{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7cbb8e2-0efb-40de-ba67-27c8052f0ad4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "num_threads = \"8\"\n",
    "os.environ[\"OMP_NUM_THREADS\"] = num_threads\n",
    "os.environ[\"OPENBLAS_NUM_THREADS\"] = num_threads\n",
    "os.environ[\"MKL_NUM_THREADS\"] = num_threads\n",
    "os.environ[\"VECLIB_MAXIMUM_THREADS\"] = num_threads\n",
    "os.environ[\"NUMEXPR_NUM_THREADS\"] = num_threads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "080ac2e1-77cb-4bb5-ada0-6f5555ab37e5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchdiffeq import odeint\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import sys\n",
    "# Extend the import path before `import fm` (presumably `fm` lives in ../../src -- confirm).\n",
    "sys.path.append(\"../../src\")\n",
    "import importlib\n",
    "import matplotlib.pyplot as plt\n",
    "import fm\n",
    "# Reload `fm` so that edits to the module are picked up when this cell is re-run.\n",
    "importlib.reload(fm)\n",
    "# Notebook-wide defaults: float32 tensors, everything on the CPU.\n",
    "torch.set_default_dtype(torch.float32)\n",
    "device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fff97bb6-f32f-46c6-854e-2ec5bc5c4f97",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the reference-process parameters (A, mu, sigma, d) and the basis U from disk,\n",
    "# then build two synthetic two-cluster marginals x0 / x1 in the column space of U.\n",
    "seed = 1\n",
    "dim = 10\n",
    "N = 128\n",
    "# NOTE(security): torch.load unpickles the file; only load data files you trust.\n",
    "data = torch.load(f\"../gaussian/data_seed_{seed}_dim_{dim}_N_{N}.pkl\")\n",
    "# The stored x0 / x1 are overwritten by the synthetic marginals built below.\n",
    "x0, x1 = data['x0'], data['x1']\n",
    "A, mu = data['A'], data['mu']\n",
    "d, sigma = data['d'], data['sigma']\n",
    "U = data['U']\n",
    "ts = torch.tensor([0., 1.])  # not referenced by later cells in this notebook\n",
    "\n",
    "# Re-orient the two-column basis U with a fixed 2x2 matrix.\n",
    "U0 = torch.tensor([[-0.7988, -0.6016],\n",
    "        [-0.6016,  0.7988]])\n",
    "U = U @ U0\n",
    "\n",
    "# Seed the global torch RNG so the sampled marginals (and everything derived from\n",
    "# them) are reproducible under Restart Kernel -> Run All.\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# From here on N is the number of synthetic samples per marginal (it no longer\n",
    "# matches the N used in the data file name); the figure file names use this value.\n",
    "N = 64\n",
    "bm = fm.LinearBridgeMatcher(A, mu)  # not referenced later; `otfm.bm` is used instead\n",
    "# Each marginal is a mixture of two isotropic 2-D Gaussian clusters (N // 2 points each),\n",
    "# mapped into the ambient space by right-multiplying with U.T.\n",
    "x0 = torch.vstack([torch.randn(N // 2, 2)*0.1 - 0.5, \n",
    "                   torch.randn(N // 2, 2)*0.25 + 0.5\n",
    "                  ]).to(device)\n",
    "x0 = x0 @ U.T \n",
    "x1 = torch.vstack([torch.randn(N // 2, 2)*0.1 - 2.5, \n",
    "                   torch.randn(N // 2, 2)*0.5 + 2.5,\n",
    "                  ]).to(device)\n",
    "x1 = x1 @ U.T\n",
    "t = torch.rand(N).to(device)  # not referenced by later cells in this notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99685ac2-6415-497e-8358-d7219815a963",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Heatmap of the drift matrix A on a symmetric [-1, 1] diverging colour scale.\n",
    "_heatmap_style = dict(vmin = -1, vmax = 1, cmap = \"RdBu_r\")\n",
    "plt.figure(figsize = (2.5, 2.5))\n",
    "plt.imshow(A, **_heatmap_style)\n",
    "plt.axis(\"off\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "357be259-ced0-4450-85aa-3072a20c415d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "importlib.reload(fm)\n",
    "\n",
    "def _stacked_marginals():\n",
    "    \"\"\"Return (samples, time_index): x0 and x1 stacked row-wise, with label 0 / 1 per row.\n",
    "\n",
    "    Each marginal's label count is taken from its own number of rows, so the labels\n",
    "    stay aligned with the samples even if x0 and x1 have different sizes.\n",
    "    Fresh tensors are built on every call so the two samplers do not share inputs.\n",
    "    \"\"\"\n",
    "    samples = torch.vstack([x0, x1])\n",
    "    time_index = torch.hstack([torch.full((x0.shape[0], ), 0), torch.full((x1.shape[0], ), 1)])\n",
    "    return samples, time_index\n",
    "\n",
    "# Sampler that is given the linear reference drift parameters (A, mu).\n",
    "otfm = fm.LinearEntropicOTFM(*_stacked_marginals(), \n",
    "                      ts = torch.tensor([0., 1.], dtype = torch.float32),\n",
    "                      sigma = sigma,\n",
    "                      A = A,\n",
    "                      mu = mu,\n",
    "                      T = 2,\n",
    "                      dim = d,\n",
    "                      device = device\n",
    "                      )\n",
    "# Baseline sampler built from the same data but without A / mu.\n",
    "otfm_null = fm.EntropicOTFM(*_stacked_marginals(), \n",
    "                      ts = torch.tensor([0., 1.], dtype = torch.float32),\n",
    "                      sigma = sigma,\n",
    "                      T = 2,\n",
    "                      dim = d,\n",
    "                      device = device\n",
    "                      )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ebfec71-571f-4d5f-a2f6-00b2041b16c0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Score network (s_model) and flow network (v_model), trained jointly on batches from `otfm`.\n",
    "s_model = fm.MLP(d = d, hidden_sizes = [64, 64, 64], time_varying=True, activation = torch.nn.ReLU)\n",
    "v_model = fm.MLP(d = d, hidden_sizes = [64, 64, 64], time_varying=True, activation = torch.nn.ReLU)\n",
    "\n",
    "alpha = 0.5  # mixing weight: L = (1 - alpha) * L_score + alpha * L_flow\n",
    "from tqdm import tqdm\n",
    "optim = torch.optim.AdamW(list(s_model.parameters()) + list(v_model.parameters()), 3e-3)\n",
    "for i in tqdm(range(1000)):\n",
    "    # The batch variables deliberately avoid the names `_s` / `_t`: other cells bind `_s`\n",
    "    # to a score model and `_t` to a time grid, and reusing those names here would\n",
    "    # silently replace them with training tensors when this cell is re-run.\n",
    "    x_batch, s_target, u_target, t_batch, t_orig = otfm.sample_bridges_flows(batch_size = 64)\n",
    "    optim.zero_grad()\n",
    "    s_fit = s_model(t_batch, x_batch)\n",
    "    v_fit = v_model(t_batch, x_batch)\n",
    "    # Both residuals are scaled by t(1 - t), which vanishes at t = 0 and t = 1.\n",
    "    loss_weight = t_orig * (1 - t_orig)\n",
    "    L_score = torch.mean((loss_weight * (s_fit - s_target))**2)\n",
    "    L_flow = torch.mean((loss_weight * (v_fit - u_target))**2)\n",
    "    L = (1-alpha)*L_score + alpha*L_flow\n",
    "    if i % 100 == 0:\n",
    "        print(L_score.item(), L_flow.item())\n",
    "    L.backward()\n",
    "    optim.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4cdbe9e-d3f9-4eeb-90d9-4de84fd6a6b5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Baseline score / flow networks, trained jointly on batches from `otfm_null`.\n",
    "s_model_null = fm.MLP(d = d, hidden_sizes = [64, 64, 64], time_varying=True, activation = torch.nn.ReLU)\n",
    "v_model_null = fm.MLP(d = d, hidden_sizes = [64, 64, 64], time_varying=True, activation = torch.nn.ReLU)\n",
    "\n",
    "alpha = 0.5  # mixing weight: L = (1 - alpha) * L_score + alpha * L_flow\n",
    "from tqdm import tqdm\n",
    "optim = torch.optim.AdamW(list(s_model_null.parameters()) + list(v_model_null.parameters()), 3e-3)\n",
    "for i in tqdm(range(1_000)):\n",
    "    # The batch variables deliberately avoid the names `_s` / `_t`: other cells bind `_s`\n",
    "    # to a score model and `_t` to a time grid, and reusing those names here would\n",
    "    # silently replace them with training tensors when this cell is re-run.\n",
    "    x_batch, s_target, u_target, t_batch, t_orig = otfm_null.sample_bridges_flows(batch_size = 64)\n",
    "    optim.zero_grad()\n",
    "    s_fit = s_model_null(t_batch, x_batch)\n",
    "    v_fit = v_model_null(t_batch, x_batch)\n",
    "    # Both residuals are scaled by t(1 - t), which vanishes at t = 0 and t = 1.\n",
    "    loss_weight = t_orig * (1 - t_orig)\n",
    "    L_score = torch.mean((loss_weight * (s_fit - s_target))**2)\n",
    "    L_flow = torch.mean((loss_weight * (v_fit - u_target))**2)\n",
    "    L = (1-alpha)*L_score + alpha*L_flow\n",
    "    if i % 100 == 0:\n",
    "        print(L_score.item(), L_flow.item())\n",
    "    L.backward()\n",
    "    optim.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17ec6219-ce5a-4860-8f00-80ac45cb80b1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torchsde\n",
    "# Choose which trained pair the following cells visualise; `what` tags the figure file names.\n",
    "# _v, _s, what = v_model, s_model, \"OU\"\n",
    "_v, _s, what = v_model_null, s_model_null, \"BM\"\n",
    "xlims = (-4.5, 4.5)\n",
    "ylims = (-4.5, 4.5)\n",
    "\n",
    "# SDE drift: learned flow plus (sigma^2 / 2) times the learned score.\n",
    "sde = fm.SDE(lambda t, x: _v(t, x) + sigma**2 / 2 * _s(t, x), sigma)\n",
    "_T = 100  # number of time points on [0, 1]\n",
    "with torch.no_grad():\n",
    "    # x0 is already a tensor: copy it with clone().detach() rather than torch.tensor(x0),\n",
    "    # which warns on tensor input.\n",
    "    xs_sde = torchsde.sdeint(sde, x0.clone().detach(), torch.linspace(0, 1, _T), method = \"euler\")\n",
    "    # ODE driven by the learned flow alone (plotted as \"PF-ODE\" in later cells).\n",
    "    xs_ode = odeint(lambda t, x: _v(t, x), x0.clone().detach(), torch.linspace(0, 1, _T))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63474aaa-472c-4cf9-8a3f-62a89d964850",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def _project_traj(xs, U):\n",
    "    \"\"\"Right-multiply every slice of a batched trajectory tensor by the matrix U.\n",
    "\n",
    "    xs: tensor of shape (batch, n, k); U: matrix of shape (k, m).\n",
    "    Returns a tensor of shape (batch, n, m). Broadcasting in `@` applies U to each\n",
    "    batch slice, so U does not have to be expanded by hand.\n",
    "    \"\"\"\n",
    "    return xs @ U\n",
    "\n",
    "# Projected trajectories and end-point samples used by the plotting cells.\n",
    "xs_sde_proj = _project_traj(xs_sde, U)\n",
    "xs_ode_proj = _project_traj(xs_ode, U)\n",
    "x0_proj, x1_proj = x0 @ U, x1 @ U"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71481868-b6e3-46b1-a39c-54479325ba8d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Projected coordinate k of every trajectory against time: SDE (left) and ODE (right).\n",
    "k = 0\n",
    "plt.figure(figsize = (5, 2.5))\n",
    "_panels_1d = [(\"SDE\", xs_sde_proj), (\"PF-ODE\", xs_ode_proj)]\n",
    "for _col, (_panel_title, _traj_proj) in enumerate(_panels_1d, start = 1):\n",
    "    plt.subplot(1, 2, _col)\n",
    "    plt.plot(torch.linspace(0, 1, _T), _traj_proj[..., k], color = 'blue', alpha = 0.1)\n",
    "    # End-point samples are drawn at t = 0 (green) and t = 1 (red).\n",
    "    plt.scatter(torch.zeros_like(x0_proj[:, k]), x0_proj[:, k], c = 'green', label = \"$p_0$\")\n",
    "    plt.scatter(torch.ones_like(x1_proj[:, k]), x1_proj[:, k], c = 'red', label = \"$p_1$\")\n",
    "    plt.xlabel(\"$t$\")\n",
    "    plt.ylabel(\"$x$\")\n",
    "    plt.legend()\n",
    "    plt.title(_panel_title)\n",
    "    plt.ylim(*ylims)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/SB_general_{what}_1d_v_time_seed_{seed}_dim_{dim}_N_{N}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7009e32-94e5-4095-a6d1-b2dabf04ceba",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Evaluate the reference drift A (x - mu) on a 100 x 100 grid in the plane spanned by U:\n",
    "# lift the grid points with U, apply the drift, and project the result back with U.T.\n",
    "x = np.linspace(-5.5, 5.5, 100)\n",
    "y = np.linspace(-5.5, 5.5, 100)\n",
    "X, Y = np.meshgrid(x, y)\n",
    "P = np.eye(d)[:, range(2)]\n",
    "# Alternative using the coordinate projection P instead of U:\n",
    "# u, v = P.T @ (A.numpy() @ ((P @ np.vstack([X.flatten(), Y.flatten()])) - mu.numpy()[:, None]))\n",
    "_U_np = U.numpy()\n",
    "_grid_pts = np.vstack([X.flatten(), Y.flatten()])\n",
    "_drift_full = A.numpy() @ ((_U_np @ _grid_pts) - mu.numpy()[:, None])\n",
    "u, v = _U_np.T @ _drift_full\n",
    "u = u.reshape(X.shape)\n",
    "v = v.reshape(X.shape)\n",
    "# Blank the field inside the box |X - mu_0| < 1, |Y - mu_1| < 1 around the projected mu,\n",
    "# which is where the drift A (x - mu) vanishes.\n",
    "_mu = mu @ U\n",
    "_near_mu_x = (X - _mu[0].item())**2 < 1\n",
    "_near_mu_y = (Y - _mu[1].item())**2 < 1\n",
    "_idx = _near_mu_x & _near_mu_y\n",
    "u[_idx] = 0\n",
    "v[_idx] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad1050ae-120d-46dc-b759-d08c21708b07",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Projected trajectories over the reference drift field: SDE (left) and ODE (right).\n",
    "plt.figure(figsize = (5, 2.5))\n",
    "_panels_2d = [(\"SDE\", xs_sde_proj), (\"PF-ODE\", xs_ode_proj)]\n",
    "for _col, (_panel_title, _traj_proj) in enumerate(_panels_2d, start = 1):\n",
    "    plt.subplot(1, 2, _col)\n",
    "    plt.streamplot(X, Y, u, v, density=0.3, color='k', linewidth=0.5, arrowsize=1)\n",
    "    plt.plot(_traj_proj[..., 0], _traj_proj[..., 1], color = 'blue', alpha = 0.1)\n",
    "    plt.scatter(x0_proj[:, 0], x0_proj[:, 1], c = 'green', label = \"$p_0$\")\n",
    "    plt.scatter(x1_proj[:, 0], x1_proj[:, 1], c = 'red', label = \"$p_1$\")\n",
    "    plt.xlabel(\"$x_0$\")\n",
    "    plt.ylabel(\"$x_1$\")\n",
    "    # plt.legend()\n",
    "    plt.title(_panel_title)\n",
    "    plt.xlim(*xlims)\n",
    "    plt.ylim(*ylims)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/SB_general_{what}_2d_seed_{seed}_dim_{dim}_N_{N}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5df39cdc-2c56-4302-b49b-5efb6e0eb51f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# varying noise level in forward simulation\n",
    "# The same learned flow `_v` and score `_s` are re-simulated with different diffusion levels;\n",
    "# results are stored per noise level in `xs_sde_sigma`.\n",
    "_T = 100  # number of time points on [0, 1]; does not depend on the noise level\n",
    "xs_sde_sigma = {}\n",
    "for _sigma in [0.01, 0.5, 1.0, 2.5]:\n",
    "    # The lambda reads `_sigma` when it is called; this is safe here because the SDE is\n",
    "    # integrated inside the same loop iteration.\n",
    "    sde = fm.SDE(lambda t, x: _v(t, x) + _sigma**2 / 2 * _s(t, x), _sigma)\n",
    "    with torch.no_grad():\n",
    "        # x0 is already a tensor: copy it with clone().detach() rather than torch.tensor(x0),\n",
    "        # which warns on tensor input.\n",
    "        _x_init = x0.clone().detach().to(torch.float32)\n",
    "        xs_sde_sigma[_sigma] = torchsde.sdeint(sde, _x_init, torch.linspace(0, 1, _T), method = \"euler\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fd6fbae-86f6-4064-af82-2ace3c49ea1f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# One panel per simulated noise level, each drawn over the reference drift field.\n",
    "plt.figure(figsize = (10, 2.5))\n",
    "for _panel, _sigma in enumerate(xs_sde_sigma, start = 1):\n",
    "    _xs_proj = _project_traj(xs_sde_sigma[_sigma], U)\n",
    "    plt.subplot(1, 4, _panel)\n",
    "    plt.streamplot(X, Y, u, v, density=0.3, color='k', linewidth=0.5, arrowsize=1)\n",
    "    plt.plot(_xs_proj[..., 0], _xs_proj[..., 1], color = 'blue', alpha = 0.1)\n",
    "    plt.scatter(x0_proj[:, 0], x0_proj[:, 1], c = 'green')\n",
    "    plt.scatter(x1_proj[:, 0], x1_proj[:, 1], c = 'red')\n",
    "    plt.xlabel(\"$x_0$\")\n",
    "    plt.ylabel(\"$x_1$\")\n",
    "    plt.xlim(*xlims)\n",
    "    plt.ylim(*ylims)\n",
    "    plt.title(f\"$\\\\sigma = {_sigma}$\")\n",
    "    plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/SB_general_{what}_2d_var_noise_seed_{seed}_dim_{dim}_N_{N}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "706da58e-a396-472b-aace-cf67826cb561",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Simulate both trained model pairs from the same initial samples for a side-by-side comparison.\n",
    "# Drift in both cases: learned flow plus (sigma^2 / 2) times the learned score.\n",
    "sde_ou = fm.SDE(lambda t, x: v_model(t, x) + sigma**2 / 2 * s_model(t, x), sigma)\n",
    "sde_bm = fm.SDE(lambda t, x: v_model_null(t, x) + sigma**2 / 2 * s_model_null(t, x), sigma)\n",
    "with torch.no_grad():\n",
    "    # x0 is already a tensor: copy it with clone().detach() rather than torch.tensor(x0),\n",
    "    # which warns on tensor input.\n",
    "    xs_sde_ou = torchsde.sdeint(sde_ou, x0.clone().detach(), torch.linspace(0, 1, _T), method = \"euler\")\n",
    "    xs_sde_bm = torchsde.sdeint(sde_bm, x0.clone().detach(), torch.linspace(0, 1, _T), method = \"euler\")\n",
    "xs_sde_ou_proj = _project_traj(xs_sde_ou, U)\n",
    "xs_sde_bm_proj = _project_traj(xs_sde_bm, U)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3745568-ee0e-4762-b050-85cd0a3a2c54",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Evaluate both score models on a 15 x 15 (time, position) grid along the first column of U.\n",
    "_ts = torch.linspace(0, 1, 15)\n",
    "_ys = torch.linspace(*ylims, 15)\n",
    "\n",
    "def _score_along_first_axis(score_model):\n",
    "    \"\"\"Return the U[:, 0]-component of `score_model` at the points y * U[:, 0].\n",
    "\n",
    "    Rows of the result follow `_ys`, columns follow `_ts`.\n",
    "    \"\"\"\n",
    "    axis_dir = U[:, 0]\n",
    "    per_time = []\n",
    "    for tau in _ts:\n",
    "        line_pts = _ys.unsqueeze(1) * axis_dir.unsqueeze(0)\n",
    "        per_time.append(score_model(tau, line_pts) @ axis_dir)\n",
    "    return torch.vstack(per_time).T\n",
    "\n",
    "with torch.no_grad():\n",
    "    score_ou = _score_along_first_axis(s_model)\n",
    "    score_bm = _score_along_first_axis(s_model_null)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0db39f6-f40b-47a6-8cd1-a22e3ffe3b15",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# (time, position) grid for the score quiver plots. It gets its own names so that it\n",
    "# does not overwrite the X / Y grid that the streamplot cells rely on.\n",
    "T_quiver, Y_quiver = np.meshgrid(_ts, _ys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a382fac6-0970-4f8c-a99d-f5a6a06256dd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Side-by-side comparison: learned score along the first projected axis (vertical arrows,\n",
    "# coloured by magnitude) under the SDE trajectories of each model pair.\n",
    "# The (time, position) grid is rebuilt here so this cell does not depend on the name\n",
    "# chosen for it elsewhere.\n",
    "T_quiver, Y_quiver = np.meshgrid(_ts, _ys)\n",
    "plt.figure(figsize = (5, 2.5))\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.quiver(T_quiver, Y_quiver, torch.zeros_like(score_ou), score_ou, score_ou.abs(), cmap = \"RdBu_r\", scale_units = 'y')\n",
    "plt.plot(torch.linspace(0, 1, _T), xs_sde_ou_proj[..., k], color = 'blue', alpha = 0.1);\n",
    "plt.scatter(torch.zeros_like(x0_proj[:, k]), x0_proj[:, k], c = 'green', label = \"$p_0$\")\n",
    "plt.scatter(torch.ones_like(x1_proj[:, k]), x1_proj[:, k], c = 'red', label = \"$p_1$\")\n",
    "plt.xlabel(\"$t$\"); plt.ylabel(\"$x$\")\n",
    "plt.ylim(*ylims)\n",
    "plt.title(\"mvOU-OTFM\")\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.quiver(T_quiver, Y_quiver, torch.zeros_like(score_bm), score_bm, score_bm.abs(), cmap = \"RdBu_r\", scale_units = 'y')\n",
    "plt.plot(torch.linspace(0, 1, _T), xs_sde_bm_proj[..., k], color = 'blue', alpha = 0.1);\n",
    "plt.scatter(torch.zeros_like(x0_proj[:, k]), x0_proj[:, k], c = 'green', label = \"$p_0$\")\n",
    "plt.scatter(torch.ones_like(x1_proj[:, k]), x1_proj[:, k], c = 'red', label = \"$p_1$\")\n",
    "plt.ylim(*ylims)\n",
    "plt.xlabel(\"$t$\"); plt.ylabel(\"$x$\")\n",
    "plt.title(\"BM-OTFM\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/SB_general_comparison_2d_seed_{seed}_dim_{dim}_N_{N}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c4a7577-56c9-4300-9b40-71a6fba69093",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Reference process: drift (x - mu) A^T with diffusion sigma, started from the x0 samples.\n",
    "sde_ref = fm.SDE(lambda t, x: (x - mu) @ A.T, sigma)\n",
    "with torch.no_grad():\n",
    "    # x0 is already a tensor: copy it with clone().detach() rather than torch.tensor(x0),\n",
    "    # which warns on tensor input.\n",
    "    xs_sde_ref = torchsde.sdeint(sde_ref, x0.clone().detach(), torch.linspace(0, 1, _T), method = \"euler\")\n",
    "xs_sde_ref_proj = _project_traj(xs_sde_ref, U)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47831f93-668e-4713-b3be-4339dd05cfeb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.mixture import GaussianMixture\n",
    "\n",
    "def _draw_density_curtain(ax, samples_1d, grid, t_level, line_fmt, fill_color):\n",
    "    \"\"\"Fit a 2-component GMM to 1-D samples and draw its density as a filled curve.\n",
    "\n",
    "    The curve lies in the plane y = t_level of the 3-D axes `ax`: the first axis is the\n",
    "    position `grid`, the vertical axis is the fitted density. Returns the fitted GMM.\n",
    "    \"\"\"\n",
    "    gmm = GaussianMixture(n_components=2, random_state=42)\n",
    "    gmm.fit(samples_1d.unsqueeze(1))\n",
    "    density = np.exp(gmm.score_samples(grid[:, None]))  # score_samples returns log-density\n",
    "    level = np.full_like(grid, t_level)\n",
    "    ax.plot(grid, level, density, line_fmt, linewidth=2, label='Curve')\n",
    "    # Two-row surface between the density curve and the zero baseline, i.e. the filled area.\n",
    "    ax.plot_surface(np.vstack([grid, grid]),\n",
    "                    np.vstack([level, level]),\n",
    "                    np.vstack([density, np.zeros_like(grid)]),\n",
    "                    color = fill_color, alpha=0.7,\n",
    "                    linewidth=0, antialiased=True)\n",
    "    return gmm\n",
    "\n",
    "fig = plt.figure(figsize=(5, 5))\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "x = np.linspace(*ylims, 100)\n",
    "\n",
    "# Marginal densities of the first projected coordinate at t = 0 (green) and t = 1 (red).\n",
    "gmm0 = _draw_density_curtain(ax, x0_proj[:, 0], x, 0.0, 'g-', 'green')\n",
    "gmm1 = _draw_density_curtain(ax, x1_proj[:, 0], x, 1.0, 'r-', 'red')\n",
    "\n",
    "_t_axis = torch.linspace(0, 1, _T)\n",
    "# All reference-process paths, faint, drawn in random order.\n",
    "for i in np.random.permutation(xs_sde_ref_proj.shape[1]):\n",
    "    ax.plot(xs_sde_ref_proj[:, i, 0], _t_axis, alpha = 0.1, c = 'k')\n",
    "\n",
    "# A random subset of 16 paths simulated from the trained OU model pair, highlighted.\n",
    "for i in np.random.permutation(xs_sde_ou_proj.shape[1])[:16]:\n",
    "    ax.plot(xs_sde_ou_proj[:, i, 0], _t_axis, alpha = 0.4, c = 'blue')\n",
    "\n",
    "# Disable the box/panes in the 3D plot\n",
    "ax.xaxis.pane.fill = False\n",
    "ax.yaxis.pane.fill = False\n",
    "ax.zaxis.pane.fill = False\n",
    "\n",
    "ax.xaxis.pane.set_edgecolor('none')\n",
    "ax.yaxis.pane.set_edgecolor('none')\n",
    "ax.zaxis.pane.set_edgecolor('none')\n",
    "\n",
    "# The axes are switched off entirely; the labels below only matter if that is re-enabled.\n",
    "ax.set_axis_off()\n",
    "\n",
    "ax.set_xlabel('$x$')\n",
    "ax.set_ylabel('$t$')\n",
    "ax.view_init(elev=30, azim=-35)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/concept.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbfdbbb3-cc67-40b1-89f6-155c40871d7d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Illustrate a single bridge: 8 paths pinned between one start point and one end point.\n",
    "_x0 = x0[0, :].unsqueeze(0) * 0.5\n",
    "_x1 = x1[-1, :].unsqueeze(0)\n",
    "_x0_proj = _x0 @ U\n",
    "_x1_proj = _x1 @ U\n",
    "_t = torch.linspace(0, 1, _T)\n",
    "# Reference drift (x - mu) A^T plus the bridge control term from `otfm.bm`; the end points\n",
    "# are tiled to one row per simulated path.\n",
    "sde_bridge = fm.SDE(lambda t, x: (x - mu) @ A.T + otfm.bm._bridge_ctrl(x, t.repeat(x.shape[0]), _x0.repeat((x.shape[0], 1)), _x1.repeat((x.shape[0], 1))), sigma)\n",
    "with torch.no_grad():\n",
    "    # repeat() already returns a new tensor, so no torch.tensor(...) copy is needed\n",
    "    # (and torch.tensor on a tensor warns).\n",
    "    xs_sde_bridge = torchsde.sdeint(sde_bridge, _x0.repeat((8, 1)), _t, method = \"euler\")\n",
    "xs_sde_bridge_proj = _project_traj(xs_sde_bridge, U)\n",
    "\n",
    "# Bridge standard deviation and mean along the first column of U, one value per time point.\n",
    "# U[:, 0] is 1-D, so no transpose is needed (.T on a non-2-D tensor is deprecated).\n",
    "_std = (U[:, 0] @ otfm.bm._bridge_cov(_t, _x0, _x1) @ U[:, 0]).sqrt()\n",
    "_mean = otfm.bm._bridge_mean(_t, _x0.repeat((_T, 1)), _x1.repeat((_T, 1)))  @ U[:, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccfedbb9-db58-4d9e-9f92-c93f76a37268",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Bridge paths with nested mean +/- n * std bands (n = 0..5) and the two pinned end points.\n",
    "plt.figure(figsize = (3.0, 2.5))\n",
    "_endpoint_style = dict(zorder = 100, s = 100)\n",
    "plt.scatter([0, ], _x0_proj[:, 0], c = 'g', **_endpoint_style)\n",
    "plt.scatter([1, ], _x1_proj[:, 0], c = 'r', **_endpoint_style)\n",
    "for _n_std in range(6):\n",
    "    _upper = _mean + _n_std*_std\n",
    "    _lower = _mean - _n_std*_std\n",
    "    # Overlapping translucent bands make the shading darker towards the mean.\n",
    "    plt.fill_between(_t, _upper, _lower, alpha=0.15, color='purple', lw = 0)\n",
    "plt.plot(_t, xs_sde_bridge_proj[:, :, 0], alpha = 0.8, c = 'purple')\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/bridge_illustration.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ad097d9-693b-42d1-9de5-c157a5f12271",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "041ea885-c326-4760-a3fd-87445983a829",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lfm",
   "language": "python",
   "name": "lfm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
