{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7cbb8e2-0efb-40de-ba67-27c8052f0ad4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Give every thread-count environment variable the same value.\n",
    "num_threads = \"4\"\n",
    "for _env_name in (\n",
    "    \"OMP_NUM_THREADS\",\n",
    "    \"OPENBLAS_NUM_THREADS\",\n",
    "    \"MKL_NUM_THREADS\",\n",
    "    \"VECLIB_MAXIMUM_THREADS\",\n",
    "    \"NUMEXPR_NUM_THREADS\",\n",
    "):\n",
    "    os.environ[_env_name] = num_threads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1710579-5a98-4f13-82ad-8866ef28f646",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchdiffeq import odeint\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import sys\n",
    "sys.path.append(\"../../src\")\n",
    "import matplotlib.pyplot as plt\n",
    "torch.set_default_dtype(torch.float32)\n",
    "import importlib\n",
    "import fm\n",
    "import torchsde\n",
    "importlib.reload(fm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3521e41d-7789-4fc0-8d05-18c5e8e92296",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Identifiers of the experiment; they are interpolated into the data, weight and figure file names below.\n",
    "seed = 1\n",
    "dim = 10\n",
    "N = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45ed4bc1-aa74-4779-ae29-2e9dae1c8770",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# NOTE(review): torch.load unpickles the file, so only point it at trusted local artefacts.\n",
    "data = torch.load(f\"data_seed_{seed}_dim_{dim}_N_{N}.pkl\")\n",
    "# Pull the stored quantities out of the dictionary in one pass.\n",
    "x0, x1, A, mu, d, sigma, U = (data[_key] for _key in ('x0', 'x1', 'A', 'mu', 'd', 'sigma', 'U'))\n",
    "ts = torch.tensor([0., 1.])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cf8e922-cc54-4984-ae80-5b4fe6331b2c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "U.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "357be259-ced0-4450-85aa-3072a20c415d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def _stack_snapshots(x0, x1):\n",
    "    \"\"\"Stack two snapshots and build the matching time-index vector.\n",
    "\n",
    "    Returns the row-wise concatenation of `x0` and `x1` together with an index\n",
    "    vector that is 0 for every row of `x0` and 1 for every row of `x1`.\n",
    "    \"\"\"\n",
    "    xs = torch.vstack([x0, x1])\n",
    "    # size the second label run by x1 (not x0) so unequal sample counts stay aligned with xs\n",
    "    t_idx = torch.hstack([torch.full((x0.shape[0], ), 0), torch.full((x1.shape[0], ), 1)])\n",
    "    return xs, t_idx\n",
    "\n",
    "# keyword arguments shared by both flow-matching objects\n",
    "_otfm_common = dict(ts = ts, sigma = sigma, T = 2, dim = d, device = torch.device('cpu'))\n",
    "# reference process with drift parameters A and mu\n",
    "otfm = fm.LinearEntropicOTFM(*_stack_snapshots(x0, x1), A = A, mu = mu, **_otfm_common)\n",
    "# counterpart constructed without A and mu\n",
    "otfm_null = fm.EntropicOTFM(*_stack_snapshots(x0, x1), **_otfm_common)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8748044-5239-484b-bdfc-113d8e31b072",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def _load_mlp(weights_path):\n",
    "    \"\"\"Create the time-varying ReLU `fm.MLP` with three hidden layers of 64 units and load its weights.\n",
    "\n",
    "    The checkpoint at `weights_path` is mapped onto the CPU (the device used for the\n",
    "    flow-matching objects in this notebook), so weights saved from a GPU run still load.\n",
    "    Returns the model with the state dict applied.\n",
    "    \"\"\"\n",
    "    model = fm.MLP(d = d, hidden_sizes = [64, 64, 64], time_varying=True, activation = torch.nn.ReLU)\n",
    "    model.load_state_dict(torch.load(weights_path, map_location = torch.device('cpu')))\n",
    "    return model\n",
    "\n",
    "_run_tag = f\"seed_{seed}_dim_{dim}_N_{N}\"\n",
    "# score (s) and flow (v) networks for the OU reference ...\n",
    "s_model = _load_mlp(f\"weights/otfm_score_{_run_tag}.pt\")\n",
    "v_model = _load_mlp(f\"weights/otfm_flow_{_run_tag}.pt\")\n",
    "\n",
    "# ... and for the null reference\n",
    "s_model_null = _load_mlp(f\"weights/otfm_null_score_{_run_tag}.pt\")\n",
    "v_model_null = _load_mlp(f\"weights/otfm_null_flow_{_run_tag}.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a189fde-b752-489e-a1bd-15874b7f35b2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def _project(sb_means, sb_vars, S_t, d_sb_means, U):\n",
    "    \"\"\"Project per-time bridge statistics onto the columns of `U`.\n",
    "\n",
    "    For every index along the leading axis, `sb_means` and `d_sb_means` are mapped\n",
    "    as m -> m^T U and `sb_vars` and `S_t` as M -> U^T M U.\n",
    "\n",
    "    Returns the projected (means, variances, S_t, mean derivatives), in that order.\n",
    "    \"\"\"\n",
    "    U_expanded = U.unsqueeze(0).expand(sb_means.shape[0], *U.shape)\n",
    "    # squeeze(1) removes only the length-1 axis that m^T U leaves for column-vector means;\n",
    "    # a bare squeeze() would also drop the leading axis when it has length 1\n",
    "    sb_means_proj = torch.bmm(sb_means.mT, U_expanded).squeeze(1)\n",
    "    sb_vars_proj = torch.bmm(U_expanded.mT, torch.bmm(sb_vars, U_expanded))\n",
    "    S_t_proj = torch.bmm(U_expanded.mT, torch.bmm(S_t, U_expanded))\n",
    "    d_sb_means_proj = torch.bmm(d_sb_means.mT, U_expanded).squeeze(1)\n",
    "    return sb_means_proj, sb_vars_proj, S_t_proj, d_sb_means_proj\n",
    "\n",
    "def _project_traj(xs, U):\n",
    "    \"\"\"Right-multiply every slice `xs[i]` of a 3-D tensor by `U`.\"\"\"\n",
    "    U_batched = U.unsqueeze(0).expand(xs.shape[0], *U.shape)\n",
    "    return torch.bmm(xs, U_batched)\n",
    "\n",
    "_v, _s, what = v_model, s_model, \"OU\"\n",
    "# _v, _s, what = v_model_null, s_model_null, \"BM\"\n",
    "xlims = (-5, 5)\n",
    "ylims = (-8, 10)\n",
    "\n",
    "x0_proj, x1_proj = x0 @ U, x1 @ U\n",
    "# SDE drift: flow network plus sigma^2 / 2 times the score network\n",
    "sde = fm.SDE(lambda t, x: _v(t, x) + sigma**2 / 2 * _s(t, x), sigma)\n",
    "_T = 100\n",
    "_ts_sim = torch.linspace(0, 1, _T)\n",
    "with torch.no_grad():\n",
    "    # x0 is already a tensor; detach().clone() makes the same independent copy as torch.tensor(x0)\n",
    "    # without the copy-construct UserWarning\n",
    "    xs_sde = torchsde.sdeint(sde, x0.detach().clone(), _ts_sim, method = \"euler\")\n",
    "    xs_ode = odeint(lambda t, x: _v(t, x), x0.detach().clone(), _ts_sim)\n",
    "xs_sde_proj = _project_traj(xs_sde, U)\n",
    "xs_ode_proj = _project_traj(xs_ode, U)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71481868-b6e3-46b1-a39c-54479325ba8d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "k = 1\n",
    "_ts_plot = torch.linspace(0, 1, _T)\n",
    "_x0_k, _x1_k = x0 @ U[:, k], x1 @ U[:, k]\n",
    "plt.figure(figsize = (5, 2.5))\n",
    "# one panel per simulated process: coordinate k of the projected trajectories against time\n",
    "for _col, (_title, _traj) in enumerate([(\"SDE\", xs_sde_proj), (\"PF-ODE\", xs_ode_proj)], start = 1):\n",
    "    plt.subplot(1, 2, _col)\n",
    "    plt.plot(_ts_plot, _traj[:, range(100), k], color = 'blue', alpha = 0.1);\n",
    "    # endpoint samples drawn at t = 0 and t = 1\n",
    "    plt.scatter(torch.zeros_like(_x0_k), _x0_k, c = 'green', label = \"$p_0$\")\n",
    "    plt.scatter(torch.ones_like(_x1_k), _x1_k, c = 'red', label = \"$p_1$\")\n",
    "    plt.title(_title)\n",
    "    plt.legend()\n",
    "    plt.xlabel(\"$t$\"); plt.ylabel(\"$x$\")\n",
    "    plt.ylim(*ylims)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/SB_Gaussian_{what}_1d_v_time_seed_{seed}_dim_{dim}_N_{N}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7009e32-94e5-4095-a6d1-b2dabf04ceba",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "xlims = (-5, 5)\n",
    "ylims = (-7.5, 7.5)\n",
    "x = np.linspace(*xlims, 100)\n",
    "y = np.linspace(*ylims, 100)\n",
    "X, Y = np.meshgrid(x, y)\n",
    "_U_np, _A_np, _mu_np = U.numpy(), A.numpy(), mu.numpy()\n",
    "# lift the 2-D grid points through U, evaluate A (x - mu), and project back with U^T\n",
    "_grid = np.vstack([X.flatten(), Y.flatten()])\n",
    "_field = _U_np.T @ (_A_np @ (_U_np @ _grid - _mu_np[:, None]))\n",
    "u, v = (_component.reshape(X.shape) for _component in _field)\n",
    "# zero the field inside the unit box around the origin so nothing is drawn there\n",
    "_idx = (np.abs(X) < 1) & (np.abs(Y) < 1)\n",
    "u[_idx] = 0\n",
    "v[_idx] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad1050ae-120d-46dc-b759-d08c21708b07",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize = (5, 2.5))\n",
    "# same layout for both processes: reference field, 100 projected trajectories, endpoint samples\n",
    "for _col, (_title, _traj) in enumerate([(\"SDE\", xs_sde_proj), (\"PF-ODE\", xs_ode_proj)], start = 1):\n",
    "    plt.subplot(1, 2, _col)\n",
    "    plt.streamplot(X, Y, u, v, density=0.3, color='k', linewidth=0.5, arrowsize=1)\n",
    "    plt.plot(_traj[:, range(100), 0], _traj[:, range(100), 1], color = 'blue', alpha = 0.1);\n",
    "    plt.scatter(x0_proj[:, 0], x0_proj[:, 1], c = 'green', label = \"$p_0$\")\n",
    "    plt.scatter(x1_proj[:, 0], x1_proj[:, 1], c = 'red', label = \"$p_1$\")\n",
    "    # no legend is drawn on these panels\n",
    "    plt.xlabel(\"$x_0$\"); plt.ylabel(\"$x_1$\")\n",
    "    plt.title(_title)\n",
    "    plt.xlim(*xlims); plt.ylim(*ylims)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/SB_Gaussian_{what}_2d_seed_{seed}_dim_{dim}_N_{N}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b206c186-724d-4a88-99d0-fce994e653fa",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# varying noise level in forward simulation\n",
    "_T = 100\n",
    "_ts_sim = torch.linspace(0, 1, _T)\n",
    "xs_sde_sigma = {}\n",
    "for _sigma in [0.01, 0.5, 1.0, 2.5]:\n",
    "    sde = fm.SDE(lambda t, x: _v(t, x) + _sigma**2 / 2 * _s(t, x), _sigma)\n",
    "    with torch.no_grad():\n",
    "        # x0 is already a tensor: copy and cast it directly instead of torch.tensor(x0, ...),\n",
    "        # which emits the copy-construct UserWarning\n",
    "        xs_sde_sigma[_sigma] = torchsde.sdeint(sde, x0.detach().clone().to(torch.float32), _ts_sim, method = \"euler\")\n",
    "\n",
    "plt.figure(figsize = (10, 2.5))\n",
    "for i, (_sigma, _xs) in enumerate(xs_sde_sigma.items()):\n",
    "    _xs_proj = _project_traj(_xs, U)\n",
    "    plt.subplot(1, len(xs_sde_sigma), i+1)\n",
    "    plt.streamplot(X, Y, u, v, density=0.3, color='k', linewidth=0.5, arrowsize=1)\n",
    "    plt.plot(_xs_proj[:, range(100), 0], _xs_proj[:, range(100), 1], color = 'blue', alpha = 0.1);\n",
    "    plt.scatter(x0_proj[:, 0], x0_proj[:, 1], c = 'green')\n",
    "    plt.scatter(x1_proj[:, 0], x1_proj[:, 1], c = 'red')\n",
    "    plt.xlabel(\"$x_0$\"); plt.ylabel(\"$x_1$\")\n",
    "    plt.xlim(*xlims); plt.ylim(*ylims)\n",
    "    plt.title(f\"$\\\\sigma = {_sigma}$\")\n",
    "# lay the figure out once, after every panel exists\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/SB_Gaussian_{what}_2d_var_noise_seed_{seed}_dim_{dim}_N_{N}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca244887-5e6a-4f37-9af1-6fd66b9898b0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# coarser 25-point time grid used for the comparison with the closed-form bridge\n",
    "t = torch.linspace(0, 1, 25)\n",
    "sde = fm.SDE(lambda t, x: v_model(t, x) + sigma**2 / 2 * s_model(t, x), sigma)\n",
    "sde_null = fm.SDE(lambda t, x: v_model_null(t, x) + sigma**2 / 2 * s_model_null(t, x), sigma)\n",
    "with torch.no_grad():\n",
    "    # x0 is already a tensor; detach().clone() copies it without the torch.tensor(x0) UserWarning\n",
    "    xs_sde = torchsde.sdeint(sde, x0.detach().clone(), t, method = \"euler\")\n",
    "    xs_sde_null = torchsde.sdeint(sde_null, x0.detach().clone(), t, method = \"euler\")\n",
    "# reuse the notebook's projection helper instead of repeating the batched matmul inline\n",
    "xs_sde_proj = _project_traj(xs_sde, U)\n",
    "xs_sde_null_proj = _project_traj(xs_sde_null, U)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35e5dc32-c26b-409d-8978-0e3e1dc199e7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "importlib.reload(fm)\n",
    "# endpoint Gaussian parameters stored with the data, plus their projections through U\n",
    "mean0, mean1 = data['mean0'], data['mean1']\n",
    "var0, var1 = data['var0'], data['var1']\n",
    "mean0_proj, mean1_proj = mean0 @ U, mean1 @ U\n",
    "var0_proj, var1_proj = U.T @ var0 @ U, U.T @ var1 @ U\n",
    "\n",
    "gsb = fm.GaussianOUSB(otfm.bm, otfm)\n",
    "sb_means, sb_vars, S_t, d_sb_means = gsb.evaluate(t, mean0, mean1, var0, var1)\n",
    "sb_means_proj, sb_vars_proj, S_t_proj, d_sb_means_proj = _project(sb_means, sb_vars, S_t, d_sb_means, U)\n",
    "xs_sde_proj = _project_traj(xs_sde, U)\n",
    "x0_proj, x1_proj = x0 @ U, x1 @ U\n",
    "# plotting grid as a (2, n_points) tensor, lifted once through U\n",
    "_x = torch.tensor(np.vstack([X.flatten(), Y.flatten()]), dtype = torch.float32)\n",
    "_x_lifted = U @ _x\n",
    "\n",
    "def _bridge_field(i):\n",
    "    \"\"\"Field d_sb_means[i] + S_t[i]^T pinv(sb_vars[i]) (x - sb_means[i]) on the lifted grid.\"\"\"\n",
    "    gain = S_t[i].T @ torch.linalg.pinv(sb_vars[i])\n",
    "    return d_sb_means[i] + gain @ (_x_lifted - sb_means[i])\n",
    "\n",
    "vs = [_bridge_field(i) for i in range(len(t))]\n",
    "vs_proj = [U.T @ _field_i for _field_i in vs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2da35aed-2cde-4d6e-b51b-673de8ef4231",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# LinearEntropicOTFM with A and mu multiplied by zero\n",
    "_xs_all = torch.vstack([x0, x1])\n",
    "# label rows by snapshot; the second run is sized by x1 so unequal sample counts stay aligned with _xs_all\n",
    "_t_idx = torch.hstack([torch.full((x0.shape[0], ), 0), torch.full((x1.shape[0], ), 1)])\n",
    "_otfm_null = fm.LinearEntropicOTFM(_xs_all, \n",
    "                      _t_idx, \n",
    "                      ts = ts,\n",
    "                      sigma = sigma,\n",
    "                      A = 0*A,\n",
    "                      mu = 0*mu,\n",
    "                      T = 2,\n",
    "                      dim = d,\n",
    "                      device = torch.device('cpu')\n",
    "                      )\n",
    "\n",
    "gsb_null = fm.GaussianOUSB(_otfm_null.bm, _otfm_null)\n",
    "sb_means_null, sb_vars_null, S_t_null, d_sb_means_null = gsb_null.evaluate(t, mean0, mean1, var0, var1)\n",
    "sb_means_proj_null, sb_vars_proj_null, S_t_proj_null, d_sb_means_proj_null = _project(sb_means_null, sb_vars_null, S_t_null, d_sb_means_null, U)\n",
    "xs_sde_proj_null = _project_traj(xs_sde_null, U)\n",
    "# closed-form field on the lifted plotting grid at every time point, then projected back through U\n",
    "vs_null = [d_sb_means_null[i] + (S_t_null[i].T @ torch.linalg.pinv(sb_vars_null[i])) @ ((U @ _x) - sb_means_null[i]) for i in range(len(t))]\n",
    "vs_proj_null = [U.T @ v for v in vs_null]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83f0270f-e4c0-4db7-bb69-0827f89a8aa8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "\n",
    "def get_cmap(solid_color):\n",
    "    \"\"\"Return a linear colormap that runs from white (at 0) to `solid_color` (at 1).\n",
    "\n",
    "    `solid_color` is indexed at positions 0, 1 and 2 for its red, green and blue components.\n",
    "    \"\"\"\n",
    "    cdict = {\n",
    "        _channel: [(0.0, 1.0, 1.0), (1.0, solid_color[_pos], solid_color[_pos])]\n",
    "        for _pos, _channel in enumerate(('red', 'green', 'blue'))\n",
    "    }\n",
    "    return matplotlib.colors.LinearSegmentedColormap('WhiteToBlue', cdict)\n",
    "\n",
    "\n",
    "from scipy.stats import multivariate_normal\n",
    "# evaluation grid for the Gaussian densities, spanning the current axis limits\n",
    "x, y = np.linspace(*xlims, 100), np.linspace(*ylims, 100)\n",
    "X, Y = np.meshgrid(x, y)\n",
    "_X = np.vstack([X.ravel(), Y.ravel()])\n",
    "# pos[i, j] holds the (x, y) pair of grid node (i, j)\n",
    "pos = np.stack([X, Y], axis = -1)\n",
    "\n",
    "def plot_bivariate(mean, cov, cm, **kwargs):\n",
    "    \"\"\"Draw 5-level contours of the N(mean, cov) density on the current axes.\n",
    "\n",
    "    The density is evaluated on the notebook-level grid `pos` / `X` / `Y`; `cm` is the\n",
    "    colormap and any extra keyword arguments are forwarded to `plt.contour`.\n",
    "    \"\"\"\n",
    "    density = multivariate_normal(mean, cov).pdf(pos)\n",
    "    plt.contour(X, Y, density, levels=5, cmap=cm, **kwargs)\n",
    "\n",
    "plt.figure(figsize = (3, 1.5))\n",
    "# (title, projected means, projected covariances, whether the panel pins its axis limits)\n",
    "_panels = [\n",
    "    (\"OU-GSB\", sb_means_proj, sb_vars_proj, True),\n",
    "    (\"BM-GSB\", sb_means_proj_null, sb_vars_proj_null, False),\n",
    "]\n",
    "for _col, (_title, _means, _covs, _pin_limits) in enumerate(_panels, start = 1):\n",
    "    plt.subplot(1, 2, _col)\n",
    "    # every third marginal, coloured by its time\n",
    "    for i in range(0, len(t), 3):\n",
    "        cm = get_cmap(matplotlib.cm.winter(t[i].item()))\n",
    "        plot_bivariate(_means[i].flatten(), _covs[i].reshape(2, 2), cm = cm)\n",
    "    plt.streamplot(X, Y, u, v, density=0.3, color='k', linewidth=0.5, arrowsize=1, zorder = -100)\n",
    "    if _pin_limits:\n",
    "        plt.xlim(*xlims); plt.ylim(*ylims)\n",
    "    plt.xlabel(\"$x_0$\"); plt.ylabel(\"$x_1$\"); plt.axis('off')\n",
    "    plt.title(_title)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/SB_Gaussian_{what}_2d_gaussian_fm_seed_{seed}_dim_{dim}_N_{N}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f78383ac-3797-4587-b196-848a6805a28d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# per-time matrices S_t^T pinv(Sigma_t) for the null and the OU bridge\n",
    "_gain_null = torch.bmm(S_t_null.mT, torch.linalg.pinv(sb_vars_null))\n",
    "_gain = torch.bmm(S_t.mT, torch.linalg.pinv(sb_vars))\n",
    "# norm of (M - M^T) for the null bridge (zero exactly when M is symmetric) ...\n",
    "_asym_null = torch.vmap(lambda x: (x - x.T).norm())(_gain_null)\n",
    "# ... and norm of M itself for the OU bridge\n",
    "_gain_norm = torch.vmap(lambda x: x.norm())(_gain)\n",
    "# a cell only echoes its final expression, so return both diagnostics together\n",
    "_asym_null, _gain_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c94ea4d7-cdf2-403c-ab6f-587f59c20d95",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def _marginal_strip(overview_title, means_proj, vars_proj, fields_proj, out_path, pin_overview_limits):\n",
    "    \"\"\"Draw and save a 1x6 strip of projected Gaussian marginals.\n",
    "\n",
    "    Panel 1 overlays every third marginal on the reference field (u, v); panels 2-6 show\n",
    "    every sixth time point with its own field from `fields_proj`. `pin_overview_limits`\n",
    "    controls whether panel 1 is clipped to (xlims, ylims). The figure is saved to `out_path`.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize = (9, 1.75))\n",
    "    plt.subplot(1, 6, 1)\n",
    "    for i in range(0, len(t), 3):\n",
    "        cm = get_cmap(matplotlib.cm.winter(t[i].item()))\n",
    "        plot_bivariate(means_proj[i].flatten(), vars_proj[i].reshape(2, 2), cm = cm)\n",
    "    plt.streamplot(X, Y, u, v, density=0.3, color='k', linewidth=0.5, arrowsize=1, zorder = -100)\n",
    "    if pin_overview_limits:\n",
    "        plt.xlim(*xlims); plt.ylim(*ylims)\n",
    "    plt.xlabel(\"$x_0$\"); plt.ylabel(\"$x_1$\"); plt.axis('off')\n",
    "    plt.title(overview_title)\n",
    "    for j, i in enumerate(range(0, len(t), 6)):\n",
    "        plt.subplot(1, 6, j+2)\n",
    "        _field = fields_proj[i]\n",
    "        plt.streamplot(X, Y, _field[0, :].reshape(X.shape), _field[1, :].reshape(X.shape), density=0.3, color='k', linewidth=0.5, arrowsize=1)\n",
    "        plot_bivariate(means_proj[i].flatten(), vars_proj[i].reshape(2, 2), cm = get_cmap(matplotlib.cm.winter(t[i].item())))\n",
    "        plt.xlim(*xlims); plt.ylim(*ylims)\n",
    "        plt.axis(\"off\")\n",
    "        plt.title(f\"$t$ = {t[i]:.2f}\")\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(out_path)\n",
    "\n",
    "_marginal_strip(\"BM-GSB\", sb_means_proj_null, sb_vars_proj_null, vs_proj_null,\n",
    "                f\"../../figures/SB_Gaussian_{what}_2d_gaussian_marginals_BM_seed_{seed}_dim_{dim}.pdf\",\n",
    "                pin_overview_limits = False)\n",
    "\n",
    "_marginal_strip(\"mvOU-GSB\", sb_means_proj, sb_vars_proj, vs_proj,\n",
    "                f\"../../figures/SB_Gaussian_{what}_2d_gaussian_marginals_OU_seed_{seed}_dim_{dim}.pdf\",\n",
    "                pin_overview_limits = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e712110b-5ed2-4237-869e-34e6ea03405c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# every second time point: sampled states as triangles, analytic marginal as contours, same colour\n",
    "for i in range(0, len(t), 2):\n",
    "    _rgba = matplotlib.cm.cool(t[i].item())\n",
    "    _samples = xs_sde_proj[i]\n",
    "    plt.scatter(_samples[:, 0], _samples[:, 1], color = _rgba, alpha = 0.3, marker = '^', label = f\"Sampled: t = {t[i]}\")\n",
    "    cm = get_cmap(_rgba)\n",
    "    plot_bivariate(sb_means_proj[i].flatten(), sb_vars_proj[i].reshape(2, 2), cm = cm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eefb5453-0ab7-4ba2-8c9d-23d27b0d75ff",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "sb_means, sb_vars, S_t, d_sb_means = gsb.evaluate(gsb.bm.ts, mean0, mean1, var0, var1)\n",
    "sb_means_proj, sb_vars_proj, S_t_proj, d_sb_means_proj = _project(sb_means, sb_vars, S_t, d_sb_means, U)\n",
    "\n",
    "def _F(t, x):\n",
    "    \"\"\"Closed-form bridge drift at time `t` for states `x` given one per row.\n",
    "\n",
    "    Interpolates the tabulated bridge statistics at `t` with `gsb.bm.interp` and returns\n",
    "    d_mean + S^T pinv(Sigma) (x - mean), with the same layout as `x`.\n",
    "    \"\"\"\n",
    "    _S_t = gsb.bm.interp(S_t, t)\n",
    "    _var_t = gsb.bm.interp(sb_vars, t)\n",
    "    _mean_t = gsb.bm.interp(sb_means, t)\n",
    "    _dmean_t = gsb.bm.interp(d_sb_means, t)\n",
    "    # NOTE(review): assumes interp returns a (d, 1) mean and (d, d) matrices, like sb_means[i] / S_t[i] - TODO confirm\n",
    "    _gain = _S_t.T @ torch.linalg.pinv(_var_t)\n",
    "    # the closed form acts on column vectors, so transpose the row-batched states in and out\n",
    "    return (_dmean_t + _gain @ (x.T - _mean_t)).T\n",
    "\n",
    "sde_gt = fm.SDE(_F, sigma)\n",
    "with torch.no_grad():\n",
    "    # integrate the analytic drift (sde_gt); simulating the learned `sde` here left sde_gt unused\n",
    "    # and made the \"ground truth\" trajectories a copy of the learned ones\n",
    "    xs_sde_gt = torchsde.sdeint(sde_gt, x0.detach().clone(), t, method = \"euler\")\n",
    "xs_sde_gt_proj = _project_traj(xs_sde_gt, U)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0f68c61-c82a-4ee2-87e7-dc6f5227c9d5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "_n_ts = len(gsb.bm.ts)\n",
    "plt.plot(xs_sde_gt_proj[:, range(100), 0], xs_sde_gt_proj[:, range(100), 1], color = 'blue', alpha = 0.1);\n",
    "# every fifth marginal on the bridge grid, followed by the final one\n",
    "for i in [*range(0, _n_ts, 5), _n_ts - 1]:\n",
    "    _rgba = matplotlib.cm.cool(gsb.bm.ts[i].item())\n",
    "    cm = get_cmap(_rgba)\n",
    "    plot_bivariate(sb_means_proj[i].flatten(), sb_vars_proj[i].reshape(2, 2), cm = cm)\n",
    "plt.scatter(x0_proj[:, 0], x0_proj[:, 1], c = 'green', alpha = 0.75)\n",
    "plt.scatter(x1_proj[:, 0], x1_proj[:, 1], c = 'red', alpha = 0.75)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cd39950-9fb6-4014-a231-988db2462b75",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "result = torch.load(f\"weights/gp_sinkhorn_seed_{seed}_dim_{dim}_N_{N}.pkl\")\n",
    "fig, (ax1, ax2) =  plt.subplots(1, 2,figsize=(14,6))\n",
    "# entries 1 and 3 of the last result item, last component dropped, projected through U\n",
    "M = result[-1][1][..., :-1] @ U.double()\n",
    "M2 = result[-1][3][..., :-1] @ U.double()\n",
    "for _ax, _paths, _title in ((ax1, M, \"Forward\"), (ax2, M2, \"Backward\")):\n",
    "    _paths = _paths.detach()\n",
    "    # one red line per path; both panels draw the first len(M) paths\n",
    "    for _path in _paths[:len(M)]:\n",
    "        _ax.plot(_path[:, 0], _path[:, 1], alpha=.3, color=\"red\")\n",
    "    # mark the first point of every path\n",
    "    _ax.scatter(_paths[:, 0, 0], _paths[:, 0, 1], zorder = 100)\n",
    "    _ax.set_title(_title)\n",
    "ax1.set_xlim(*xlims); ax1.set_ylim(*ylims)\n",
    "ax2.set_xlim(*xlims); ax2.set_ylim(*ylims)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "547da702-5698-4a74-a075-3dd290edcd84",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import glob\n",
    "import seaborn as sb\n",
    "\n",
    "def _mean_row(path):\n",
    "    \"\"\"One-row frame with the column means of the CSV at `path`, skipping its first data column.\"\"\"\n",
    "    return pd.read_csv(path, index_col = 0).iloc[:, 1:].mean(0).to_frame().T\n",
    "\n",
    "files = glob.glob(f\"eval/eval_seed_*_dim_*_N_{N}.csv\")\n",
    "# file names look like eval_seed_<seed>_dim_<dim>_N_<N>.csv: fields 2 and 4 carry seed and dim\n",
    "_name_fields = [os.path.basename(f).split(\"_\") for f in files]\n",
    "seeds = [int(_fields[2]) for _fields in _name_fields]\n",
    "dims = [int(_fields[4]) for _fields in _name_fields]\n",
    "df = pd.concat([_mean_row(f).assign(seed=s, dim=d, file=f) for (f, s, d) in zip(files, seeds, dims)])\n",
    "# matching NLSB results, in the same file order, appended as extra columns\n",
    "_df = pd.concat([_mean_row(f\"eval/eval_nlsb_seed_{s}_dim_{d}_N_{N}.csv\") for (s, d) in zip(seeds, dims)])\n",
    "df = pd.concat([df, _df], axis = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57fcbd42-384c-4c95-a53a-5e113d684e53",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def _melt_by_method(frame, labels):\n",
    "    \"\"\"Long-format copy of `frame` with one row per (seed, dim, file, method).\n",
    "\n",
    "    `labels` maps metric column names to the method names shown on the plots; its keys\n",
    "    select the melted columns. `dim` is converted to a categorical column.\n",
    "    \"\"\"\n",
    "    melted = frame.melt(id_vars = ['seed', 'dim', 'file'], value_vars = list(labels))\n",
    "    melted = melted.replace({'variable' : labels}).rename(columns={'variable' : 'method',})\n",
    "    melted.dim = melted.dim.astype('category')\n",
    "    return melted\n",
    "\n",
    "_df1 = _melt_by_method(df, {'OU_EOT_BW' : 'OU-GSB', \n",
    "                            'EOT_BW' : 'BM-GSB', \n",
    "                            'GP_sinkhorn_fwd_BW' : 'IPML(→)', \n",
    "                            'GP_sinkhorn_rev_BW' : 'IPML(←)', \n",
    "                            'NLSB_BW' : 'NLSB', \n",
    "                           })\n",
    "_df2 = _melt_by_method(df, {'OU_EOT_vf' : 'OU-GSB', \n",
    "                            'EOT_vf' : 'BM-GSB',\n",
    "                            'GP_sinkhorn_vf' : 'IPML',\n",
    "                            'NLSB_vf' : 'NLSB', \n",
    "                           })\n",
    "\n",
    "plt.figure(figsize = (7, 3.5))\n",
    "# left: marginal error per method and dimension\n",
    "plt.subplot(1, 2, 1)\n",
    "sb.boxplot(_df1, x = \"method\", y = \"value\", hue = \"dim\", palette = \"Blues\")\n",
    "plt.title(\"Schrödinger bridge marginals\")\n",
    "plt.ylabel(\"$\\\\rm{BW}_2^2$\")\n",
    "# right: vector-field error per method and dimension\n",
    "plt.subplot(1, 2, 2)\n",
    "sb.boxplot(_df2, x = \"method\", y = \"value\", hue = \"dim\", palette = \"Blues\")\n",
    "plt.ylabel(\"Vector field RMSE\")\n",
    "plt.title(\"Schrödinger bridge vector field\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/SB_Gaussian_benchmarking.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d895cec-d35a-4be3-b114-77d6b52e77c0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from toolz import interleave\n",
    "# mean and standard deviation per (dim, method), with methods spread over the columns\n",
    "_df1_mean = _df1.groupby(['dim', 'method'])[[\"value\"]].agg({'value' : ['mean', ]}).unstack()\n",
    "_df1_std = _df1.groupby(['dim', 'method'])[[\"value\"]].agg({'value' : ['std', ]}).unstack()\n",
    "_df1_mean_str = _df1_mean.applymap(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)\n",
    "_df1_std_str = _df1_std.applymap(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)\n",
    "# join mean and std column by column; key by column position, because looking the key up in the\n",
    "# row index raised IndexError whenever there were fewer dims than methods\n",
    "_df1_str = pd.DataFrame({i : _df1_mean_str.iloc[:, i].str.cat(_df1_std_str.iloc[:, i], sep = \" $\\\\pm$ \") for i in range(_df1_mean_str.shape[1])})\n",
    "_df1_str.columns = _df1_mean.columns\n",
    "# bold the method with the smallest mean in every row\n",
    "for i, j in enumerate(np.argmin(_df1_mean.values, 1)):\n",
    "    _df1_str.iloc[i, j] = \"\\\\textbf{\" + _df1_str.iloc[i, j] + \"}\"\n",
    "# print(_df1_str.to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c70b745-d958-42dc-a302-fd30d5a12843",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# mean and standard deviation per (dim, method), with methods spread over the columns\n",
    "_df2_mean = _df2.groupby(['dim', 'method'])[[\"value\"]].agg({'value' : ['mean', ]}).unstack()\n",
    "_df2_std = _df2.groupby(['dim', 'method'])[[\"value\"]].agg({'value' : ['std', ]}).unstack()\n",
    "_df2_mean_str = _df2_mean.applymap(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)\n",
    "_df2_std_str = _df2_std.applymap(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)\n",
    "# join mean and std column by column; key by column position, because looking the key up in the\n",
    "# row index raised IndexError whenever there were fewer dims than methods\n",
    "_df2_str = pd.DataFrame({i : _df2_mean_str.iloc[:, i].str.cat(_df2_std_str.iloc[:, i], sep = \" $\\\\pm$ \") for i in range(_df2_mean_str.shape[1])})\n",
    "_df2_str.columns = _df2_mean.columns\n",
    "# bold the method with the smallest mean in every row\n",
    "for i, j in enumerate(np.argmin(_df2_mean.values, 1)):\n",
    "    _df2_str.iloc[i, j] = \"\\\\textbf{\" + _df2_str.iloc[i, j] + \"}\"\n",
    "# print(_df2_str.to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2ef2fc8-8b88-44b6-abea-16170384e4cf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# reorder the method columns of each table by position, then print both side by side as LaTeX\n",
    "_bw_cols = _df1_str.iloc[:, [4, 0, 1, 2, 3]]\n",
    "_vf_cols = _df2_str.iloc[:, [3, 0, 1, 2]]\n",
    "print(pd.concat([_bw_cols, _vf_cols], axis = 1).to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a64f1f91-e9c8-452d-a99e-299fcb1af443",
   "metadata": {},
   "outputs": [],
   "source": [
    "files = glob.glob(f\"eval/eval_seed_*_dim_*_N_{N}.csv\")\n",
    "# file names look like eval_seed_<seed>_dim_<dim>_N_<N>.csv: fields 2 and 4 carry seed and dim\n",
    "_name_fields = [os.path.basename(f).split(\"_\") for f in files]\n",
    "seeds = [int(_fields[2]) for _fields in _name_fields]\n",
    "dims = [int(_fields[4]) for _fields in _name_fields]\n",
    "# full tables this time (no averaging), each tagged with its seed, dim and source file\n",
    "df = pd.concat([pd.read_csv(f, index_col = 0).assign(seed=s, dim=d, file=f) for (f, s, d) in zip(files, seeds, dims)])\n",
    "# matching NLSB tables without their first data column, appended as extra columns\n",
    "_df = pd.concat([pd.read_csv(f\"eval/eval_nlsb_seed_{s}_dim_{d}_N_{N}.csv\", index_col = 0).iloc[:, 1:] for (s, d) in zip(seeds, dims)])\n",
    "df = pd.concat([df, _df], axis = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd26ec08-718b-4865-9129-ce64db8514aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "_d = np.sort(df.dim.unique())\n",
    "# metric column -> method name shown in the legend\n",
    "_bw_labels = {'OU_EOT_BW' : 'OU-GSB', \n",
    "              'EOT_BW' : 'BM-GSB', \n",
    "              'NLSB_BW' : 'NLSB', \n",
    "              'GP_sinkhorn_fwd_BW' : 'IPML(→)', \n",
    "              'GP_sinkhorn_rev_BW' : 'IPML(←)'}\n",
    "_df = df.melt(id_vars = ['t', 'dim'], value_vars = list(_bw_labels)).replace({'variable' : _bw_labels})\n",
    "# one facet per dimension, one line per method\n",
    "g = sb.FacetGrid(_df, col=\"dim\", hue = 'variable')\n",
    "g.map_dataframe(sb.lineplot, x = 't', y = 'value')\n",
    "g.add_legend()\n",
    "g.set_ylabels('$\\\\rm{BW}_2^2$')\n",
    "g.set_xlabels('$t$')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ecdf60b-20be-42e3-81e8-d6f82933c392",
   "metadata": {},
   "outputs": [],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f49501a-a4ce-4c0b-b3c9-4f3e882b5df8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lfm",
   "language": "python",
   "name": "lfm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
