{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0144fd10-bfc3-4e1f-bbc8-a8400f08209b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import autograd\n",
    "import autograd.numpy as np\n",
    "import torchdiffeq\n",
    "from torchdiffeq import odeint\n",
    "import geomloss\n",
    "from tqdm import tqdm\n",
    "import importlib\n",
    "import math\n",
    "import torchsde\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "torch.set_default_dtype(torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f4918c1-8e06-4246-b3bc-25925dd6873d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "N = 500\n",
    "T = 5\n",
    "dim = 10\n",
    "betamax = 5.0\n",
    "seed = 1\n",
    "reg = 'vf'\n",
    "data = torch.load(f\"sim_twowell_N_{N}_T_{T}_dim_{dim}_D_0.25_beta_{betamax}.pkl\", weights_only = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "164d1237-8d48-4f38-8017-39df043038bf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sklearn as sk\n",
    "from sklearn import decomposition\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Scatter of the first two coordinates of all samples, coloured by data['t_idx'].\n",
    "plt.figure(figsize = (2.5, 2.5))\n",
    "i = 0\n",
    "plt.scatter(data['x'][:, i], data['x'][:, i+1], c = data['t_idx'], alpha = 0.25, s = 10, rasterized = True)\n",
    "plt.xlabel(\"$x_0$\"); plt.ylabel(\"$x_1$\")\n",
    "# Set the title before tight_layout so the layout pass reserves room for it\n",
    "# (same order as the x_0-versus-t figure cell that follows).\n",
    "plt.title(\"Data\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/twowell_data_x0x1_dim_{dim}_beta_{betamax}.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f227be33-4f54-4384-ae31-bd153e7ba40f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Re-derive the number of time points and the ambient dimension from the loaded data.\n",
    "T = len(np.unique(data['t_idx']))\n",
    "dim = data[\"x\"].shape[1]\n",
    "ts = np.linspace(0, data['t_final'], T)\n",
    "\n",
    "# Observation time of every sample, with a small Gaussian jitter so points do not overlap.\n",
    "_t_jittered = ts[data['t_idx']] + np.random.randn(len(data['t_idx']))*0.01\n",
    "\n",
    "plt.figure(figsize = (2.5, 2.5))\n",
    "plt.scatter(_t_jittered, data['x'][:, 0], alpha = 0.05, s = 1, color = \"red\", rasterized = True)\n",
    "plt.xlabel(\"$t$\")\n",
    "plt.ylabel(\"$x_0$\")\n",
    "plt.title(\"Data\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/twowell_data_dim_{dim}_beta_{betamax}.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d80af493-3a49-40bf-ad6a-fb67ccb6f86c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "name = \"default_teacherforcing\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b4549ee-8196-45b4-8807-81ac32667e57",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../src\")\n",
    "import models, utils\n",
    "from torch import optim\n",
    "importlib.reload(models)\n",
    "\n",
    "# Score fitting\n",
    "# One tensor of samples per observed time index, in increasing index order.\n",
    "X = [torch.tensor(data[\"x\"][data[\"t_idx\"] == i, :], device = device, dtype = torch.float32) for i in np.sort(np.unique(data[\"t_idx\"]))]\n",
    "s = models.NCScoreFunc(d = dim, hidden_sizes = [64, 64, 64], activation = torch.nn.ReLU, time_dependent = True).to(device)\n",
    "# sigmas = torch.linspace(1, -3, 5, device = device).exp()\n",
    "# Five values running from exp(0) down to exp(-2).\n",
    "sigmas = torch.linspace(0, -2, 5, device = device).exp()\n",
    "# map_location lets a checkpoint that was written on a GPU load on the CPU fallback chosen for `device`.\n",
    "s.load_state_dict(torch.load(f'weights/params_NCScoreFunc_{name}_reg_{reg}_N_{N}_T_{T}_dim{dim}_beta_{betamax}_seed_{seed}_final.pt', map_location = device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8216d98-c976-4f08-8ce0-24087b8c04d6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "P = torch.vstack([torch.eye(2), torch.zeros(dim-2, 2)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2742d7d-60f6-49b1-b5a8-eb5fd21f9ee1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "s.to(device)\n",
    "# One LangevinSampler per snapshot time. The time is bound through a default argument so that\n",
    "# each lambda keeps its own value instead of all of them seeing the last `_s` of the loop.\n",
    "samplers = [models.LangevinSampler(lambda x, sigma, _t = torch.scalar_tensor(_s).to(device): s(_t, x, sigma), \n",
    "                      torch.randn(1_000, dim).to(device), \n",
    "                      sigmas = sigmas, dt = 3e-3, n_iter = 5000) for _s in ts]\n",
    "# The loop variable is `sampler`, not `s`, so it cannot shadow the score network the lambdas call.\n",
    "x_sample = torch.vstack([sampler.sample().cpu()[None, ...] for sampler in samplers])\n",
    "X_all = torch.vstack(X).cpu()\n",
    "s.cpu()\n",
    "x_min, x_max = (X_all[:, 0]).min()-1, (X_all[:, 0]).max()+1\n",
    "y_min, y_max = (X_all[:, 1]).min()-1, (X_all[:, 1]).max()+1\n",
    "\n",
    "# The evaluation grid is identical for every panel, so build it once. indexing = \"ij\" is the\n",
    "# layout torch.meshgrid already used by default; passing it explicitly avoids the warning.\n",
    "x, y = torch.meshgrid(torch.linspace(x_min, x_max, 15), torch.linspace(y_min, y_max, 15), indexing = \"ij\")\n",
    "_x = torch.vstack([x.flatten(), y.flatten()]).T\n",
    "\n",
    "plt.figure(figsize = (10, 2.5))\n",
    "for i in range(len(X)):\n",
    "    t = torch.scalar_tensor(ts[i])\n",
    "    plt.subplot(1, len(X), i+1)\n",
    "    with torch.no_grad():\n",
    "        # P.T lifts the 2-D grid into dim coordinates (all but the first two are zero);\n",
    "        # multiplying the output by P keeps its first two components.\n",
    "        _v = s(t, _x @ P.T, sigmas[-1].cpu()) @ P\n",
    "        # _v = s(t, torch.tensor(_x @ P.T, dtype = torch.float32)) @ P\n",
    "    # plt.contourf(x, y, torch.linalg.norm(_v, dim = 1).reshape(x.shape), levels = 20, cmap = \"bone\");\n",
    "    plt.quiver(_x[:, 0], _x[:, 1], _v[:, 0], _v[:, 1], torch.linalg.norm(_v, dim = 1), cmap='Blues', linewidth = 1, scale_units = 'xy')\n",
    "    # Observed samples in red, Langevin samples in blue.\n",
    "    plt.scatter(X[i].cpu()[:, 0], X[i].cpu()[:, 1], s = 0.1, c = 'r', alpha = 1, rasterized = True)\n",
    "    plt.scatter(x_sample[i, :, 0], x_sample[i, :, 1], s = 5, c = 'b', alpha = 0.1, rasterized = True)\n",
    "    plt.xlim(x_min, x_max); plt.ylim(y_min, y_max)\n",
    "    plt.axis(\"off\")\n",
    "    plt.title(f\"t = {ts[i]:.2f}\")\n",
    "    plt.xlabel(\"x\"); plt.ylabel(\"x\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/twowell_score_dim_{dim}_beta_{betamax}.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "878d6530-e6c1-4cf0-98e6-19d27dea1c37",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "importlib.reload(models)\n",
    "ts = torch.linspace(0, 1, len(X), device = device)\n",
    "dt = ts[1]-ts[0]\n",
    "s = s.to(device)\n",
    "m_ratios = [x.shape[0] / X[0].shape[0] for x in X]\n",
    "D = 0.5**2\n",
    "hidden_sizes = [64, 64, 64]\n",
    "odeint_options = {'method' : 'euler', 'options' : {'step_size' : 0.1}}\n",
    "\n",
    "v_upfi = models.ODEFlowGrowth(d = dim, kwargs_v = {'hidden_sizes' : hidden_sizes, 'time_dependent' : False}, \n",
    "                                       kwargs_g = {'hidden_sizes' : hidden_sizes, 'time_dependent' : False}).to(device)\n",
    "v_upfi.load_state_dict(torch.load(f'weights/params_UPFI_ODEFlowGrowth_{name}_reg_{reg}_N_{N}_T_{T}_dim{dim}_beta_{betamax}_seed_{seed}_final.pt'))\n",
    "v_pfi = models.VectorField(d = dim, hidden_sizes = hidden_sizes, time_dependent = True).to(device)\n",
    "v_pfi.load_state_dict(torch.load(f'weights/params_PFI_VectorField_{name}_reg_{reg}_N_{N}_T_{T}_dim{dim}_beta_{betamax}_seed_{seed}_final.pt'))\n",
    "v_ode = models.ODEFlowGrowthCoupled(d = dim, hidden_sizes = hidden_sizes, time_dependent = True).to(device)\n",
    "v_ode.load_state_dict(torch.load(f'weights/params_ODE_ODEFlowGrowthCoupled_{name}_reg_{reg}_N_{N}_T_{T}_dim{dim}_beta_{betamax}_seed_{seed}_final.pt'))\n",
    "v_tigon = models.ODEFlowGrowth(d = dim, kwargs_v = {'hidden_sizes' : hidden_sizes, 'time_dependent' : True}, \n",
    "                                       kwargs_g = {'hidden_sizes' : hidden_sizes, 'time_dependent' : True}).to(device)\n",
    "v_tigon.load_state_dict(torch.load(f'weights/params_TIGON_ODEFlowGrowth_{name}_reg_{reg}_N_{N}_T_{T}_dim{dim}_beta_{betamax}_seed_{seed}_final.pt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0b919ba-d8bc-4446-a971-0485ff923116",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Path of the UPFI checkpoint; includes the N_{N} component used by the load call for v_upfi.\n",
    "f'weights/params_UPFI_ODEFlowGrowth_{name}_reg_{reg}_N_{N}_T_{T}_dim{dim}_beta_{betamax}_seed_{seed}_final.pt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d46e1ade-f8f9-43fe-adce-d6d4c5da4a4e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Initial batches for the simulations below; element [0] of each sampler's return value is used.\n",
    "_batch_kwargs = {'batch_size' : 256, 'replacement' : True}\n",
    "_mass_ratios = torch.tensor(m_ratios).to(device)\n",
    "x0 = utils.sample_batch(X, **_batch_kwargs)[0]\n",
    "x0_ode = utils.sample_batch(X, add_noise = True, noise_level = 0.1, **_batch_kwargs)[0]\n",
    "x0_mass = utils.sample_batch_upfi(X, _mass_ratios, **_batch_kwargs)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b50c1c41-966a-4229-bb89-73d9606fabaf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "for _net in (v_pfi, v_upfi, v_ode, v_tigon):\n",
    "    _net.to(device)\n",
    "\n",
    "# Using SDE\n",
    "# Diffusion vector for the UPFI state: 0 for the first column, D**0.5 for the remaining dim columns.\n",
    "_sigma_upfi = torch.cat([torch.tensor([0, ]), torch.full((dim, ), D**0.5)]).to(device)\n",
    "sde_pfi = models.SDE(lambda t, x: v_pfi(t, x), sigma = D**0.5)\n",
    "sde_upfi = models.SDE(lambda t, x: v_upfi(t, x), sigma = _sigma_upfi)\n",
    "\n",
    "# PFI and UPFI are integrated as SDEs; TIGON and ODE are integrated as ODEs with odeint_options.\n",
    "with torch.no_grad():\n",
    "    xs_t_pfi = torchsde.sdeint(sde_pfi, x0, ts, method = \"euler\").cpu()\n",
    "    xs_t_upfi = torchsde.sdeint(sde_upfi, x0_mass, ts, method = \"euler\").cpu()\n",
    "    xs_t_tigon = odeint(v_tigon, x0_mass, ts, **odeint_options).cpu()\n",
    "    xs_t_ode = odeint(v_ode, x0_mass, ts, **odeint_options).cpu()\n",
    "\n",
    "# Move the networks back to the CPU for the plotting cells.\n",
    "v_pfi.cpu(); v_upfi.cpu(); v_ode.cpu(); v_tigon.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82d629c7-f842-4495-b854-8db82ffd6852",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import plotting\n",
    "\n",
    "# The evaluation grid is identical for every panel, so build it once. indexing = \"ij\" is the\n",
    "# layout torch.meshgrid already used by default; passing it explicitly avoids the warning.\n",
    "x, y = torch.meshgrid(torch.linspace(x_min, x_max, 15), torch.linspace(y_min, y_max, 15), indexing = \"ij\")\n",
    "_x = torch.vstack([x.flatten(), y.flatten()]).T\n",
    "\n",
    "plt.figure(figsize = (10, 2.5))\n",
    "for i in range(len(X)):\n",
    "    t = torch.scalar_tensor(ts[i])\n",
    "    plt.subplot(1, len(X), i+1)\n",
    "    with torch.no_grad():\n",
    "        # P.T lifts the 2-D grid into dim coordinates; P keeps the first two output components.\n",
    "        _v = v_pfi(t, _x @ P.T) @ P\n",
    "    # plt.contourf(x, y, torch.linalg.norm(_v, dim = 1).reshape(x.shape), levels = 20);\n",
    "    # plt.colorbar()\n",
    "    # Arrow colour is the vector norm, clamped at 10.\n",
    "    plt.quiver(_x[:, 0], _x[:, 1], _v[:, 0], _v[:, 1], torch.clamp_max(torch.linalg.norm(_v, dim = 1), 10), cmap='Blues', linewidth = 1, scale_units = 'xy')\n",
    "    # Observed samples in red, simulated PFI samples (already on the CPU) in blue.\n",
    "    plt.scatter(X[i].cpu()[:, 0], X[i].cpu()[:, 1], s = 1, c = 'r', alpha = 0.1, rasterized = True)\n",
    "    plt.scatter(xs_t_pfi[i, :, 0], xs_t_pfi[i, :, 1], c = 'blue', alpha = 0.5, s = 1, rasterized = True)\n",
    "    # plt.plot(xs_t[:, :, 0].cpu(), xs_t[:, :, 1].cpu(), c = 'cyan', alpha = 0.05)\n",
    "    plt.xlim(x_min, x_max); plt.ylim(y_min, y_max)\n",
    "    plt.title(f\"t = {ts[i]:.2f}\")\n",
    "    plt.xlabel(\"x\"); plt.ylabel(\"x\");\n",
    "    plt.axis(\"off\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/twowell_PFI_dim_{dim}_beta_{betamax}.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e5f46be-d0ae-49b9-8a8a-fa2e4dfcac43",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "v_upfi.cpu()\n",
    "\n",
    "# The evaluation grid is identical for every panel, so build it once. indexing = \"ij\" is the\n",
    "# layout torch.meshgrid already used by default; passing it explicitly avoids the warning.\n",
    "x, y = torch.meshgrid(torch.linspace(x_min, x_max, 15), torch.linspace(y_min, y_max, 15), indexing = \"ij\")\n",
    "_x = torch.vstack([x.flatten(), y.flatten()]).T\n",
    "\n",
    "plt.figure(figsize = (10, 5))\n",
    "for i in range(len(X)):\n",
    "    t = torch.scalar_tensor(ts[i])\n",
    "    plt.subplot(2, 5, i+1)\n",
    "    with torch.no_grad():\n",
    "        # Pass the panel's time explicitly rather than IPython's `_` (last-output) variable,\n",
    "        # whose value depends on which cell happened to produce output most recently.\n",
    "        _g = v_upfi.g_net(t, _x @ P.T)\n",
    "        _v = v_upfi.v_net(t, _x @ P.T) @ P\n",
    "    # plt.contourf(x, y, _g.reshape(x.shape), levels = 20, cmap = \"bone\");\n",
    "    # plt.colorbar()\n",
    "    plt.quiver(_x[:, 0], _x[:, 1], _v[:, 0], _v[:, 1], torch.clamp_max(torch.linalg.norm(_v, dim = 1), 10), cmap='Blues', linewidth = 1, scale_units = 'xy')\n",
    "    plt.scatter(X[i].cpu()[:, 0], X[i].cpu()[:, 1], s = 5, c = 'r', alpha = 0.1, rasterized = True)\n",
    "    # Simulated state: columns 1 and 2 are plotted as positions, exp(column 0) sets the marker size.\n",
    "    plt.scatter(xs_t_upfi[i, :, 1], xs_t_upfi[i, :, 2], s = 250*xs_t_upfi[i, :, 0].exp(), c = 'blue', alpha = 0.5, rasterized = True)\n",
    "    # plt.plot(X_batch_t[:, :, 1], X_batch_t[:, :, 2], c = 'cyan', alpha = 0.05)\n",
    "    plt.xlim(x_min, x_max); plt.ylim(y_min, y_max)\n",
    "    plt.title(f\"t = {ts[i]:.2f}\")\n",
    "    plt.xlabel(\"x\"); plt.ylabel(\"x\")\n",
    "    plt.axis(\"off\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/twowell_UPFI_dim_{dim}_beta_{betamax}.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2f78027-be7e-4cfb-96a7-529918a3dbad",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# r_t: share of the population with x_0 > 0 at each time point.\n",
    "# Tensor .sum() replaces the Python builtin sum, which walks the tensor one element at a time.\n",
    "ratios_true = [(x[:, 0] > 0).sum().item() / x.shape[0] for x in X]\n",
    "\n",
    "def _weighted_positive_fraction(xs):\n",
    "    \"\"\"Per time point, the exp(column 0)-weighted share of particles whose column 1 (x_0) is > 0.\n",
    "\n",
    "    Only the first len(X) time points of `xs` are used.\n",
    "    \"\"\"\n",
    "    fractions = []\n",
    "    for snapshot in xs[:len(X)]:\n",
    "        weight = snapshot[:, 0].exp()\n",
    "        fractions.append(weight[snapshot[:, 1] > 0].sum().item() / weight.sum().item())\n",
    "    return fractions\n",
    "\n",
    "# PFI has no weight column, so its ratio is a plain particle count on column 0 (x_0).\n",
    "_panels = [(\"PFI\", [(xs_t_pfi[i, :, 0] > 0).sum().item() / xs_t_pfi.shape[1] for i in range(len(X))]),\n",
    "           (\"UPFI\", _weighted_positive_fraction(xs_t_upfi)),\n",
    "           (\"ODE\", _weighted_positive_fraction(xs_t_ode))]\n",
    "\n",
    "plt.figure(figsize = (7.5, 2.5))\n",
    "for _k, (_title, ratios) in enumerate(_panels):\n",
    "    plt.subplot(1, 3, _k+1)\n",
    "    plt.plot(ts.cpu(), ratios_true, label = \"True\")\n",
    "    plt.plot(ts.cpu(), ratios, label = \"Fit\")\n",
    "    plt.title(_title); plt.legend()\n",
    "    plt.xlabel(\"$t$\"); plt.ylabel(\"$r_t$\"); plt.ylim(0.4, 1)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/twowell_massratios_dim_{dim}_beta_{betamax}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6efd06c9-59c2-4a48-96b8-03d41364dfdd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Total mass over time: snapshot sizes relative to the first snapshot (\"True\") against\n",
    "# the sum of exp(column 0) over the simulated particles (\"Fit\").\n",
    "plt.figure(figsize = (5, 2.5))\n",
    "for _k, (_title, _xs) in enumerate([(\"UPFI\", xs_t_upfi), (\"ODE\", xs_t_ode)]):\n",
    "    plt.subplot(1, 2, _k+1)\n",
    "    _mass_true = [x.shape[0] / X[0].shape[0] for x in X]\n",
    "    _mass_fit = [_xs[i, :, 0].exp().sum().item() for i in range(_xs.shape[0])]\n",
    "    plt.plot(ts.cpu(), _mass_true, label = \"True\")\n",
    "    plt.plot(ts.cpu(), _mass_fit, label = \"Fit\")\n",
    "    plt.title(_title)\n",
    "    plt.legend(); plt.xlabel(\"t\"); plt.ylabel(\"$m_t$\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/twowell_totalmass_dim_{dim}_beta_{betamax}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d765ee0-fe79-492c-b626-24b4b9d09722",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Reference growth rate as a function of x[0]; the argument t is not used.\n",
    "beta = lambda x, t: betamax*((np.tanh(2*x[0]) + 1)/2)\n",
    "\n",
    "# NOTE: x, y and _g are the grid and the UPFI growth values left behind by the UPFI quiver cell.\n",
    "_g_true = np.vstack([beta([_x, _y], 0) for (_x, _y) in zip(x, y)])\n",
    "_fields = [(\"$g(x)$ : True\", _g_true), (\"$g(x)$ : UPFI\", _g.reshape(x.shape))]\n",
    "\n",
    "plt.figure(figsize = (5, 2.5))\n",
    "for _k, (_title, _field) in enumerate(_fields):\n",
    "    plt.subplot(1, 2, _k+1)\n",
    "    plt.contourf(x, y, _field, cmap = \"YlGnBu\", vmin = 0, vmax = 5, levels = 20)\n",
    "    plt.colorbar()\n",
    "    # Overlay every observed snapshot in faint black.\n",
    "    for i in range(len(X)):\n",
    "        plt.scatter(X[i].cpu()[:, 0], X[i].cpu()[:, 1], s = 1, c = 'k', alpha = 0.05, rasterized = True)\n",
    "    plt.title(_title)\n",
    "    plt.xlabel('$x_0$'); plt.ylabel('$x_1$')\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/twowell_growth_dim_{dim}_beta_{betamax}.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a80bfe26-7462-49b9-91f2-05d1f0f39949",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "v_upfi.to(device)\n",
    "plt.figure(figsize = (8, 3))\n",
    "# Left: reference growth rate beta evaluated at the data points.\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], c = beta(data['x'].T, None), cmap = \"Oranges\", vmin = -5, vmax = 5, alpha = 0.5)\n",
    "plt.colorbar()\n",
    "# Right: UPFI growth network at the same points. The time argument is ts[0] rather than\n",
    "# IPython's `_` (last-output) variable, whose value depends on execution history.\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], c = v_upfi.g_net(ts[0], torch.tensor(data['x'], dtype = torch.float32).to(device)).flatten().detach().cpu(), vmin = 0, vmax = 5, cmap = \"Oranges\", alpha = 0.5)\n",
    "plt.colorbar()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84167805-b1cb-4cab-9ea5-ae56b43b9d34",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import ot \n",
    "import DeepRUOT.models\n",
    "f_net = DeepRUOT.models.FNet(in_out_dim=dim, hidden_dim=128, n_hiddens=3, activation='leakyrelu').to(device)\n",
    "sf2m_score_model=DeepRUOT.models.scoreNet2(in_out_dim=dim, hidden_dim=128,  activation='leakyrelu').float().to(device)\n",
    "# map_location lets checkpoints that were written on a GPU load on the CPU fallback chosen for `device`.\n",
    "f_net.load_state_dict(torch.load(f\"deepRUOT/model_result_dim_{dim}_beta_{betamax}\", map_location = device))\n",
    "sf2m_score_model.load_state_dict(torch.load(f\"deepRUOT/score_model_result_dim_{dim}_beta_{betamax}\", map_location = device))\n",
    "\n",
    "class _SDE(torch.nn.Module):\n",
    "    \"\"\"SDE wrapper for torchsde.sdeint over states whose column 0 is kept apart from the rest.\n",
    "\n",
    "    The state x is split as [x[..., :1] | y] with y = x[..., 1:]. All three callables are\n",
    "    evaluated on y only.\n",
    "\n",
    "    ode_drift : callable (t, y), added to the score gradient to drive the y columns.\n",
    "    growth    : callable (t, y), drives column 0.\n",
    "    score     : object whose compute_gradient(t, y) is added to the drift.\n",
    "    sigma     : scalar multiplier of the diffusion term.\n",
    "    \"\"\"\n",
    "    noise_type = \"diagonal\"\n",
    "    sde_type = \"ito\"\n",
    "    def __init__(self, ode_drift, growth, score, sigma=1.0):\n",
    "        super().__init__()\n",
    "        self.drift = ode_drift\n",
    "        self.growth = growth\n",
    "        self.score = score\n",
    "        self.sigma = sigma\n",
    "    def f(self, t, x):\n",
    "        \"\"\"Drift of the full state: [growth(t, y), drift(t, y) + score.compute_gradient(t, y)].\"\"\"\n",
    "        y = x[..., 1:]\n",
    "        drift=self.drift(t,y)\n",
    "        growth=self.growth(t,y)\n",
    "        num = y.shape[0]\n",
    "        # The score model is given one time value per row.\n",
    "        t = t.expand(num, 1) \n",
    "        return torch.hstack([growth, drift+self.score.compute_gradient(t,y)])\n",
    "    def g(self, t, x):\n",
    "        \"\"\"Diagonal diffusion: sigma times pad_zeros_upfi applied to ones shaped like y.\"\"\"\n",
    "        y = x[..., 1:]\n",
    "        # presumably pad_zeros_upfi prepends a zero column so column 0 receives no noise -- confirm in utils\n",
    "        return utils.pad_zeros_upfi(torch.ones_like(y))*self.sigma \n",
    "\n",
    "# sigma is (0.25 / (T-1))**0.5; the simulation cells integrate this SDE on a time axis scaled by (T-1).\n",
    "sde_deepruot = _SDE(f_net.v_net, f_net.g_net, sf2m_score_model, sigma=(0.25 / (T-1))**0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e582d836-d548-4aea-93ab-b77a9e15304b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "f_net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faecbe8b-7d9e-4401-bf1b-a35b27a4be96",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Marginal interpolation\n",
    "for _net in (v_pfi, v_upfi, v_ode, v_tigon):\n",
    "    _net.to(device)\n",
    "\n",
    "# Every method starts from the first 100 samples of the first snapshot.\n",
    "x0 = X[0][range(100), :]\n",
    "x0_mass = utils.pad_zeros_upfi(x0)\n",
    "\n",
    "# Diffusion vector for the UPFI state: 0 for the first column, D**0.5 for the remaining dim columns.\n",
    "_sigma_upfi = torch.cat([torch.tensor([0, ]), torch.full((dim, ), D**0.5)]).to(device)\n",
    "sde_pfi = models.SDE(lambda t, x: v_pfi(t, x), sigma = D**0.5)\n",
    "sde_upfi = models.SDE(lambda t, x: v_upfi(t, x), sigma = _sigma_upfi)\n",
    "\n",
    "with torch.no_grad():\n",
    "    xs_t_pfi = torchsde.sdeint(sde_pfi, x0, ts, method = \"euler\").cpu()\n",
    "    xs_t_upfi = torchsde.sdeint(sde_upfi, x0_mass, ts, method = \"euler\").cpu()\n",
    "    xs_t_tigon = odeint(v_tigon, x0_mass, ts, **odeint_options).cpu()\n",
    "    xs_t_ode = odeint(v_ode, x0_mass, ts, **odeint_options).cpu()\n",
    "\n",
    "# DeepRUOT is integrated outside no_grad and detached afterwards, on a time axis scaled by (T-1).\n",
    "_deepruot_paths = torchsde.sdeint(sde_deepruot, x0_mass.to(device), ts = ts.to(device)*(T-1), dt = 0.01*(T-1))\n",
    "xs_t_deepruot = _deepruot_paths.detach().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b9f33bd-cad6-42b9-a848-58bf9bb60e44",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import evals\n",
    "import seaborn as sb\n",
    "importlib.reload(evals)\n",
    "# Energy distance on path space \n",
    "plt.figure(figsize = (5, 3))\n",
    "# Position columns of each simulation; every method except PFI carries an extra leading column.\n",
    "_paths_by_method = {\n",
    "    \"UPFI\" : xs_t_upfi[..., 1:],\n",
    "    \"PFI\" : xs_t_pfi,\n",
    "    \"ODE\" : xs_t_ode[..., 1:],\n",
    "    \"TIGON\" : xs_t_tigon[..., 1:],\n",
    "    \"DeepRUOT\" : xs_t_deepruot[..., 1:],\n",
    "}\n",
    "# The first two axes are swapped before comparing against data['x_paths'].\n",
    "_energy_by_method = {}\n",
    "for _method, _paths in _paths_by_method.items():\n",
    "    _energy_by_method[_method] = evals.energy_distance_paths(_paths.permute((1, 0, 2)).numpy(), data['x_paths'])\n",
    "sb.barplot(_energy_by_method, palette = \"tab10\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72eb3a50-11e4-4bc8-b95a-1f78719497f0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Try forward simulation\n",
    "# v_tigon is integrated below as well, so it is moved to the device together with the other networks.\n",
    "v_pfi.to(device); v_upfi.to(device); v_ode.to(device); v_tigon.to(device)\n",
    "# Dense time grid: 100 points on [0, 1] instead of the len(X) snapshot times.\n",
    "_ts = torch.linspace(0, 1, 100)\n",
    "sde_pfi = models.SDE(lambda t, x: v_pfi(t, x), sigma = D**0.5)\n",
    "sde_upfi = models.SDE(lambda t, x: v_upfi(t, x), sigma = torch.cat([torch.tensor([0, ]), torch.full((dim, ), D**0.5)]).to(device))\n",
    "with torch.no_grad():\n",
    "    xs_t_pfi = torchsde.sdeint(sde_pfi, x0, _ts, method = \"euler\").cpu()\n",
    "    xs_t_upfi = torchsde.sdeint(sde_upfi, x0_mass, _ts, method = \"euler\").cpu()\n",
    "    xs_t_tigon = odeint(v_tigon, x0_mass, _ts, **odeint_options).cpu()\n",
    "    xs_t_ode = odeint(v_ode, x0_mass, _ts, **odeint_options).cpu()\n",
    "# DeepRUOT is integrated outside no_grad and detached afterwards, on a time axis scaled by (T-1).\n",
    "xs_t_deepruot = torchsde.sdeint(\n",
    "        sde_deepruot,\n",
    "        x0_mass.to(device),\n",
    "        dt=0.01*(T-1),\n",
    "        ts=_ts.to(device)*(T-1),\n",
    "    ).detach().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ddfbd6b-30fe-49c5-b3cd-a219f2a386dc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib.patheffects as pe\n",
    "\n",
    "def _plot_paths_panel(position, title, x0_paths, x0_final):\n",
    "    \"\"\"Draw one panel: x_0 trajectories, the observed snapshots, and the final branch fractions.\n",
    "\n",
    "    position : 1-based subplot index in the 1 x 6 row.\n",
    "    title    : panel title.\n",
    "    x0_paths : x_0 coordinate of the trajectories, laid out as (len(_ts), n_paths) for plt.plot.\n",
    "    x0_final : x_0 coordinate of every trajectory at the last time point.\n",
    "    \"\"\"\n",
    "    plt.subplot(1, 6, position)\n",
    "    plt.plot(_ts, x0_paths, alpha = 0.1, color = 'blue')\n",
    "    plt.scatter(ts[data['t_idx']].cpu() + np.random.randn(len(data['t_idx']))*0.01, data['x'][:, 0], alpha = 0.1, s = 1, color = \"red\", rasterized = True)\n",
    "    # r is the fraction of trajectories ending with x_0 < 0; 1-r is written at height 1, r at height -1.\n",
    "    r = (1.0*(x0_final < 0)).mean()\n",
    "    plt.annotate(f\"{1-r:.2f}\", (0.8, 1.0), path_effects=[pe.withStroke(linewidth=2, foreground=\"white\")])\n",
    "    plt.annotate(f\"{r:.2f}\", (0.8, -1.0), path_effects=[pe.withStroke(linewidth=2, foreground=\"white\")])\n",
    "    plt.ylim(-2, 2)\n",
    "    plt.xlabel(\"$t$\"); plt.ylabel(\"$x_0$\"); plt.axis(\"off\")\n",
    "    plt.title(title)\n",
    "\n",
    "plt.figure(figsize = (10.0, 2.5))\n",
    "# weights_only = False, as in the main data load: the .pkl is indexed like a dict, not a tensor state dict.\n",
    "_data = torch.load(\"sim_twowell_N_500_T_100_dim_10_D_0.25_beta_0.0.pkl\", weights_only = False)\n",
    "# x_paths and the PFI paths have no extra leading column, so x_0 is index 0 for them.\n",
    "# The other simulations carry an extra column 0, so x_0 is index 1 there.\n",
    "_plot_paths_panel(1, \"True\", _data['x_paths'][range(100), ..., 0].T, _data['x_paths'][:, -1, 0])\n",
    "_plot_paths_panel(2, \"UPFI\", xs_t_upfi[..., 1], xs_t_upfi[-1, :, 1])\n",
    "_plot_paths_panel(3, \"PFI\", xs_t_pfi[..., 0], xs_t_pfi[-1, :, 0])\n",
    "_plot_paths_panel(4, \"ODE\", xs_t_ode[..., 1], xs_t_ode[-1, :, 1])\n",
    "_plot_paths_panel(5, \"TIGON\", xs_t_tigon[..., 1], xs_t_tigon[-1, :, 1])\n",
    "_plot_paths_panel(6, \"DeepRUOT\", xs_t_deepruot[..., 1], xs_t_deepruot[-1, :, 1])\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/twowell_paths_dim_{dim}_beta_{betamax}.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "897826e7-616a-4311-ac5e-e69e8cb6072f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "xs_t_deepruot[..., 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "455945d5-ec12-4b0b-a112-eed1d1e9bb0f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# x_0 against time for every simulated particle; column 0 of each state sets the marker\n",
    "# size (through exp) and the marker colour.\n",
    "_runs = [(\"UPFI\", xs_t_upfi), (\"ODE\", xs_t_ode), (\"DeepRUOT\", xs_t_deepruot), (\"TIGON\", xs_t_tigon)]\n",
    "plt.figure(figsize = (10, 2.5))\n",
    "for _k, (_title, _xs) in enumerate(_runs):\n",
    "    plt.subplot(1, 4, _k+1)\n",
    "    _col0 = _xs[..., 0]\n",
    "    _t_grid = torch.tile(_ts[:, None], (1, _xs.shape[1]))\n",
    "    plt.scatter(_t_grid, _xs[..., 1], s = _col0.exp()*2.5, c = _col0, alpha = 0.1, cmap = 'viridis', facecolors = 'none', rasterized = True)\n",
    "    plt.ylim(-2, 2); plt.xlabel(\"$t$\"); plt.ylabel(\"$x_0$\"); plt.title(_title)\n",
    "    # Draw the colorbar itself fully opaque even though the markers are transparent.\n",
    "    cb=plt.colorbar(); cb.solids.set(alpha=1)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/twowell_paths_masses_dim_{dim}_beta_{betamax}.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83cc1df0-a94b-4718-84d0-f9465b53bf4a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "sf2m_score_model.cpu()\n",
    "\n",
    "# The evaluation grid is identical for every panel, so build it once. indexing = \"ij\" is the\n",
    "# layout torch.meshgrid already used by default; passing it explicitly avoids the warning.\n",
    "x, y = torch.meshgrid(torch.linspace(x_min, x_max, 15), torch.linspace(y_min, y_max, 15), indexing = \"ij\")\n",
    "_x = torch.vstack([x.flatten(), y.flatten()]).T\n",
    "\n",
    "plt.figure(figsize = (12, 2))\n",
    "for (i, t) in enumerate(ts):\n",
    "    plt.subplot(1, T, i+1)\n",
    "    with torch.no_grad():\n",
    "        # The score model is given the snapshot index i as its time input, one value per grid row.\n",
    "        _g = sf2m_score_model(torch.full((_x.shape[0], 1), i), _x @ P.T)\n",
    "    plt.contourf(x, y, _g.reshape(x.shape), levels = 15, cmap = \"RdBu_r\",);\n",
    "    plt.colorbar()\n",
    "    plt.scatter(data['x'][data['t_idx'] == i, 0], data['x'][data['t_idx'] == i, 1], s = 1, alpha = 0.1, rasterized = True, c = 'k')\n",
    "    plt.xlabel('$x_0$'); plt.ylabel('$x_1$')\n",
    "    plt.title(f\"t = {t:.2f}\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/twowell_DeepRUOT_score_dim_{dim}_beta_{betamax}.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "302e39f8-59da-4e0d-b97b-bc3715a7f0d2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "v_ode.cpu()\n",
    "\n",
    "# The evaluation grid is identical for every panel, so build it once. indexing = \"ij\" is the\n",
    "# layout torch.meshgrid already used by default; passing it explicitly avoids the warning.\n",
    "x, y = torch.meshgrid(torch.linspace(x_min, x_max, 15), torch.linspace(y_min, y_max, 15), indexing = \"ij\")\n",
    "_x = torch.vstack([x.flatten(), y.flatten()]).T\n",
    "\n",
    "plt.figure(figsize = (12, 2))\n",
    "for (i, t) in enumerate(ts):\n",
    "    # One panel per time point instead of a hard-coded 5.\n",
    "    plt.subplot(1, len(ts), i+1)\n",
    "    with torch.no_grad():\n",
    "        _g = v_ode.F_net(torch.scalar_tensor(t), _x @ P.T)\n",
    "    plt.contourf(x, y, _g.reshape(x.shape), levels = 15, cmap = \"RdBu_r\");\n",
    "    plt.colorbar()\n",
    "    plt.scatter(X[i].cpu()[:, 0], X[i].cpu()[:, 1], s = 1, c = 'k', alpha = 0.1)\n",
    "    plt.xlabel('$x_0$'); plt.ylabel('$x_1$')\n",
    "    plt.title(f\"t = {t:.2f}\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/twowell_ODE_fitness_masses_dim_{dim}_beta_{betamax}.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ef616e0-2e5b-4cf2-a430-96acbcf6ca01",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gen_data\n",
    "import pandas as pd\n",
    "import seaborn as sb\n",
    "import utils\n",
    "importlib.reload(utils)\n",
    "importlib.reload(gen_data)\n",
    "\n",
    "# Get ground truth vector field \n",
    "Psi_gt, Psi_thresh = gen_data.Psi(data['x'], None, dim = dim), 1.0\n",
    "vf_gt = -torch.tensor(autograd.elementwise_grad(lambda x,t : gen_data.Psi(x, t, dim = dim))(data['x'], None), dtype = torch.float32)\n",
    "g_gt = gen_data.beta(data['x'].T, None, betamax)\n",
    "\n",
    "v_pfi.to(device); v_upfi.to(device); v_ode.to(device); v_tigon.to(device)\n",
    "with torch.no_grad():\n",
    "    # Get inferred vector field\n",
    "    vf_pfi = torch.vstack([v_pfi(t, x) for (t, x) in zip(ts, X)]).cpu()\n",
    "    vf_deepruot = torch.vstack([(T-1)*f_net.v_net(t*(T-1), x) for (t, x) in zip(ts, X)]).cpu()\n",
    "    vf_upfi = torch.vstack([v_upfi.v_net(t, x) for (t, x) in zip(ts, X)]).cpu()\n",
    "    vf_tigon = torch.vstack([v_tigon.v_net(t, x) for (t, x) in zip(ts, X)]).cpu()\n",
    "    vf_ode = torch.vstack([v_ode.dF(t, x) for (t, x) in zip(ts, X)]).cpu()\n",
    "    # Get inferred growth rate\n",
    "    g_upfi = torch.vstack([v_upfi.g_net(t, x) for (t, x) in zip(ts, X)]).cpu()\n",
    "    g_tigon = torch.vstack([v_tigon.g_net(t, x) for (t, x) in zip(ts, X)]).cpu()\n",
    "    g_ode = torch.vstack([v_ode.F_net(t, x) for (t, x) in zip(ts, X)]).cpu()\n",
    "    g_deepruot = torch.vstack([(T-1)*f_net.g_net(t*(T-1), x) for (t, x) in zip(ts, X)]).cpu()\n",
    "\n",
    "df = pd.DataFrame({\"Psi\" : Psi_gt,})\n",
    "# for _dist, what in [(lambda x, y: utils.cos_dist(x[:, range(2)], y[:, range(2)]), \"cos\"), (lambda u, v: (u-v)[:, range(2)].norm(2, 1), \"l2\")]:\n",
    "for _dist, what in [(lambda x, y: utils.cos_dist(x, y), \"cos\"), (lambda u, v: (u-v).norm(2, 1), \"l2\")]:\n",
    "    df.loc[:, f\"dv_PFI_{what}\"] = _dist(vf_gt, vf_pfi).numpy()\n",
    "    df.loc[:, f\"dv_DeepRUOT_{what}\"] = _dist(vf_gt, vf_deepruot).numpy()\n",
    "    df.loc[:, f\"dv_UPFI_{what}\"]= _dist(vf_gt, vf_upfi).numpy()\n",
    "    df.loc[:, f\"dv_TIGON_{what}\"]= _dist(vf_gt, vf_tigon).numpy()\n",
    "    df.loc[:, f\"dv_ODE_{what}\"] = _dist(vf_gt, vf_ode).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96232f63-68ab-4892-bc68-4c1cbd5a1288",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize = (3, 3))\n",
    "sb.barplot({\"UPFI\" : np.corrcoef(g_upfi.T, g_gt.T)[0, 1],\n",
    "            \"TIGON\" : np.corrcoef(g_tigon.T, g_gt.T)[0, 1],\n",
    "            \"DeepRUOT\" : np.corrcoef(g_deepruot.T, g_gt.T)[0, 1],\n",
    "            \"ODE\" : np.corrcoef(g_ode.T, g_gt.T)[0, 1]})\n",
    "plt.title(\"Growth rates\"); plt.ylabel(\"Correlation\")\n",
    "plt.savefig(f\"../../figures/twowell_growth_correlation_dim_{dim}_beta_{betamax}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "352c9c13-e33f-4da7-bc2b-029732c1d53e",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize = (5, 3))\n",
    "_df = df[df.Psi > Psi_thresh].loc[:, [\"dv_UPFI_cos\", \"dv_DeepRUOT_cos\", \"dv_ODE_cos\", \"dv_PFI_cos\", \"dv_TIGON_cos\"]]\n",
    "_df.columns=_df.columns.str.split('_').str[1]\n",
    "sb.boxplot(_df)\n",
    "plt.title(\"Vector field\"); plt.ylabel(\"Cosine distance\")\n",
    "plt.savefig(f\"../../figures/twowell_vf_cosine_dim_{dim}_beta_{betamax}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3818b34-c92c-45e3-b541-3357c9527c92",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize = (5, 3))\n",
    "_df = df[df.Psi > Psi_thresh].loc[:, [\"dv_UPFI_l2\", \"dv_DeepRUOT_l2\", \"dv_ODE_l2\", \"dv_PFI_l2\", \"dv_TIGON_l2\"]]\n",
    "_df.columns=_df.columns.str.split('_').str[1]\n",
    "sb.boxplot(_df)\n",
    "plt.title(\"Vector field\"); plt.ylabel(\"L2 distance\")\n",
    "plt.savefig(f\"../../figures/twowell_vf_l2_dim_{dim}_beta_{betamax}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ada99b8a-99ca-4508-a41a-a4846fe3e36b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotting\n",
    "importlib.reload(plotting)\n",
    "\n",
    "vmin = Psi_thresh\n",
    "vmax = np.quantile(Psi_gt, 0.95)\n",
    "\n",
    "plt.figure(figsize = (12, 2.5))\n",
    "plt.subplot(1, 6, 1)\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], alpha = 0.25, s = 30, c = Psi_gt, vmin = vmin, vmax = vmax, rasterized = True, cmap = \"GnBu\")\n",
    "plotting.plot_stream_vectorfield(data['x'], vf_gt, ax = plt.gca(), color = 'k'); plt.title(\"$v_{\\\\rm{true}}$\")\n",
    "plt.xlabel(\"$x_0$\"); plt.ylabel(\"$x_1$\")\n",
    "plt.axis('off')\n",
    "plt.subplot(1, 6, 2)\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], alpha = 0.25, s = 30, c = Psi_gt, vmin = vmin, vmax = vmax, rasterized = True, cmap = \"GnBu\")\n",
    "plotting.plot_stream_vectorfield(data['x'], vf_upfi, ax = plt.gca(), color = 'k'); plt.title(\"$v_{\\\\rm{UPFI}}$\")\n",
    "plt.xlabel(\"$x_0$\"); plt.ylabel(\"$x_1$\")\n",
    "plt.axis('off')\n",
    "plt.subplot(1, 6, 6)\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], alpha = 0.25, s = 30, c = Psi_gt, vmin = vmin, vmax = vmax, rasterized = True, cmap = \"GnBu\")\n",
    "plotting.plot_stream_vectorfield(data['x'], vf_deepruot, ax = plt.gca(), color = 'k'); plt.title(\"$v_{\\\\rm{DeepRUOT}}$\")\n",
    "plt.xlabel(\"$x_0$\"); plt.ylabel(\"$x_1$\")\n",
    "plt.axis('off')\n",
    "plt.subplot(1, 6, 4)\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], alpha = 0.25, s = 30, c = Psi_gt, vmin = vmin, vmax = vmax, rasterized = True, cmap = \"GnBu\")\n",
    "plotting.plot_stream_vectorfield(data['x'], vf_ode, ax = plt.gca(), color = 'k'); plt.title(\"$v_{\\\\rm{ODE}}$\")\n",
    "plt.xlabel(\"$x_0$\"); plt.ylabel(\"$x_1$\")\n",
    "plt.axis('off')\n",
    "plt.subplot(1, 6, 3)\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], alpha = 0.25, s = 30, c = Psi_gt, vmin = vmin, vmax = vmax, rasterized = True, cmap = \"GnBu\")\n",
    "plotting.plot_stream_vectorfield(data['x'], vf_pfi, ax = plt.gca(), color = 'k'); plt.title(\"$v_{\\\\rm{PFI}}$\")\n",
    "plt.xlabel(\"$x_0$\"); plt.ylabel(\"$x_1$\")\n",
    "plt.axis('off')\n",
    "plt.subplot(1, 6, 5)\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], alpha = 0.25, s = 30, c = Psi_gt, vmin = vmin, vmax = vmax, rasterized = True, cmap = \"GnBu\")\n",
    "plotting.plot_stream_vectorfield(data['x'], vf_tigon, ax = plt.gca(), color = 'k'); plt.title(\"$v_{\\\\rm{TIGON}}$\")\n",
    "plt.xlabel(\"$x_0$\"); plt.ylabel(\"$x_1$\")\n",
    "plt.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/twowell_vf_stream_dim_{dim}_beta_{betamax}.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "785b54de-da7b-4e76-9d0e-0f07655c8763",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize = (2.5, 2.5))\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], alpha = 0.25, s = 30, c = Psi_gt, vmin = vmin, vmax = vmax, rasterized = True)\n",
    "plt.scatter(data['centroids'][:, 0], data['centroids'][:, 1], marker = \"x\", color = \"r\", s = 100)\n",
    "plt.axis(\"off\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6783051f-4649-4023-a4ba-59d7303483ce",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "importlib.reload(evals)\n",
    "sde_deepruot.to(device)\n",
    "probs_upfi, probs_pfi, probs_ode, probs_deepruot, probs_tigon = [], [], [], [], []\n",
    "_centroids = torch.tensor(data['centroids'], dtype = torch.float32)\n",
    "for i in range(T):\n",
    "    probs_upfi.append(evals.get_centroid_probs(X[i],\n",
    "                                               lambda x: torchsde.sdeint(sde_upfi, utils.pad_zeros_upfi(x), ts[[i, -1]], method = \"euler\")[-1, ..., 1:].cpu() if i < T-1 else x.cpu(), \n",
    "                                               _centroids, \n",
    "                                              ))\n",
    "    probs_pfi.append(evals.get_centroid_probs(X[i], \n",
    "                                             lambda x: torchsde.sdeint(sde_pfi, x, ts[[i, -1]], method = \"euler\")[-1, ...].cpu() if i < T-1 else x.cpu(), \n",
    "                                             _centroids))\n",
    "    probs_deepruot.append(evals.get_centroid_probs(X[i], \n",
    "                                               lambda x: torchsde.sdeint(sde_deepruot, utils.pad_zeros_upfi(x), ts[[i, -1]]*(T-1), dt = 0.01*(T-1), method = \"euler\")[-1, ..., 1:].cpu() if i < T-1 else x.cpu(), \n",
    "                                            _centroids))\n",
    "    probs_ode.append(evals.get_centroid_probs(X[i], \n",
    "                                               lambda x: odeint(v_ode, utils.pad_zeros_upfi(x), ts[[i, -1]], **odeint_options)[-1, ..., 1:].cpu() if i < T-1 else x.cpu(), \n",
    "                                            _centroids, n_sample = 1))\n",
    "    probs_tigon.append(evals.get_centroid_probs(X[i], \n",
    "                                               lambda x: odeint(v_tigon, utils.pad_zeros_upfi(x), ts[[i, -1]], **odeint_options)[-1, ..., 1:].cpu() if i < T-1 else x.cpu(), \n",
    "                                            _centroids, n_sample = 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0289da3a-2bd8-495a-92a2-b47be85b4b93",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize = (12.5, 2.5))\n",
    "plt.subplot(1, 6, 1)\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], c = data['probs'][:, 0], cmap = \"RdBu_r\", alpha = 0.1, s = 30, rasterized = True, vmin = -0.15, vmax = 1.15); plt.title(\"True\")\n",
    "plotting.plot_stream_vectorfield(data['x'], vf_gt, ax = plt.gca(), color = 'k'); plt.title(\"$v_{\\\\rm{true}}$\")\n",
    "plt.axis(\"off\")\n",
    "plt.subplot(1, 6, 2)\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], c = torch.vstack(probs_upfi)[:, 0], cmap = \"RdBu_r\", alpha = 0.1, s = 30, rasterized = True, vmin = -0.15, vmax = 1.15); plt.title(\"UPFI\")\n",
    "plotting.plot_stream_vectorfield(data['x'], vf_upfi, ax = plt.gca(), color = 'k'); plt.title(\"$v_{\\\\rm{UPFI}}$\")\n",
    "plt.axis(\"off\")\n",
    "plt.subplot(1, 6, 3)\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], c = torch.vstack(probs_deepruot)[:, 0], cmap = \"RdBu_r\", alpha = 0.1, s = 30, rasterized = True, vmin = -0.15, vmax = 1.15); plt.title(\"DeepRUOT\")\n",
    "plotting.plot_stream_vectorfield(data['x'], vf_deepruot, ax = plt.gca(), color = 'k'); plt.title(\"$v_{\\\\rm{DeepRUOT}}$\")\n",
    "plt.axis(\"off\")\n",
    "plt.subplot(1, 6, 4)\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], c = torch.vstack(probs_ode)[:, 0], cmap = \"RdBu_r\", alpha = 0.1, s = 30, rasterized = True, vmin = -0.15, vmax = 1.15); plt.title(\"ODE\")\n",
    "plotting.plot_stream_vectorfield(data['x'], vf_ode, ax = plt.gca(), color = 'k'); plt.title(\"$v_{\\\\rm{ODE}}$\")\n",
    "plt.axis(\"off\")\n",
    "plt.subplot(1, 6, 5)\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], c = torch.vstack(probs_pfi)[:, 0], cmap = \"RdBu_r\", alpha = 0.1, s = 30, rasterized = True, vmin = -0.15, vmax = 1.15); plt.title(\"PFI\")\n",
    "plotting.plot_stream_vectorfield(data['x'], vf_pfi, ax = plt.gca(), color = 'k'); plt.title(\"$v_{\\\\rm{PFI}}$\")\n",
    "plt.axis(\"off\")\n",
    "plt.subplot(1, 6, 6)\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], c = torch.vstack(probs_tigon)[:, 0], cmap = \"RdBu_r\", alpha = 0.1, s = 30, rasterized = True, vmin = -0.15, vmax = 1.15); plt.title(\"TIGON\")\n",
    "plotting.plot_stream_vectorfield(data['x'], vf_tigon, ax = plt.gca(), color = 'k'); plt.title(\"$v_{\\\\rm{TIGON}}$\")\n",
    "plt.axis(\"off\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/twowell_probs_dim_{dim}_beta_{betamax}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faa92169-b741-4bb5-be35-be64354d4e77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import seaborn as sb\n",
    "# plt.figure(figsize = (3, 3))\n",
    "# sb.barplot({\"UPFI\" : (probs_upfi - data['probs']).abs().sum(-1).mean().item(), \n",
    "#             \"PFI\" : (probs_pfi - data['probs']).abs().sum(-1).mean().item(), \n",
    "#            }, palette = \"tab10\")\n",
    "# plt.title(\"Total variation distance: fate probabilities\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b075edc1-3195-4145-a5b2-8ad69f54415f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "import glob\n",
    "from toolz import interleave\n",
    "import numpy as np\n",
    "files = glob.glob(\"evals/df_err_paths*.csv\")\n",
    "dims = [int(os.path.basename(f).split(\"_\")[8].split(\"dim\")[-1]) for f in files]\n",
    "df = pd.concat([pd.read_csv(f, index_col = 0).T for f in files])\n",
    "df.loc[:, \"dims\"] = dims\n",
    "_df1_mean = df.groupby('dims').agg(['mean', ])\n",
    "_df1_std = df.groupby('dims').agg(['std', ])\n",
    "_df1_mean_str = _df1_mean.applymap(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)\n",
    "_df1_std_str = _df1_std.applymap(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)\n",
    "_df1_str = pd.DataFrame({_df1_mean_str.columns[i] : _df1_mean_str.iloc[:, i].str.cat(_df1_std_str.iloc[:, i], sep = \" $\\\\pm$ \") for i in range(_df1_mean_str.shape[1])})\n",
    "for i, j in enumerate(np.argmin(_df1_mean.values, 1)):\n",
    "    _df1_str.iloc[i, j] = \"\\\\textbf{\" + _df1_str.iloc[i, j] + \"}\"\n",
    "_df1_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45ed0831-5734-4bcb-86e4-6722b7b4fcc6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "files = glob.glob(\"evals/df_vectorfield*.csv\")\n",
    "dims = [int(os.path.basename(f).split(\"_\")[7].split(\"dim\")[-1]) for f in files]\n",
    "df = pd.concat([pd.DataFrame(pd.read_csv(f, index_col = 0).mean(0)[1:]).T for f in files])\n",
    "df = df.loc[:, df.columns.str.contains(\"_cos\")]\n",
    "df.loc[:, \"dims\"] = dims\n",
    "_df2_mean = df.groupby('dims').agg(['mean', ])\n",
    "_df2_std = df.groupby('dims').agg(['std', ])\n",
    "_df2_mean_str = _df2_mean.applymap(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)\n",
    "_df2_std_str = _df2_std.applymap(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)\n",
    "_df2_str = pd.DataFrame({_df2_mean_str.columns[i] : _df2_mean_str.iloc[:, i].str.cat(_df2_std_str.iloc[:, i], sep = \" $\\\\pm$ \") for i in range(_df2_mean_str.shape[1])})\n",
    "for i, j in enumerate(np.argmin(_df2_mean.values, 1)):\n",
    "    _df2_str.iloc[i, j] = \"\\\\textbf{\" + _df2_str.iloc[i, j] + \"}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c354548d-65f9-4dba-b6af-55286b00b000",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "df = pd.concat([pd.DataFrame(pd.read_csv(f, index_col = 0).mean(0)[1:]).T for f in files])\n",
    "df = df.loc[:, df.columns.str.contains(\"_l2\")]\n",
    "df.loc[:, \"dims\"] = dims\n",
    "_df3_mean = df.groupby('dims').agg(['mean', ])\n",
    "_df3_std = df.groupby('dims').agg(['std', ])\n",
    "_df3_mean_str = _df3_mean.applymap(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)\n",
    "_df3_std_str = _df3_std.applymap(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)\n",
    "_df3_str = pd.DataFrame({_df3_mean_str.columns[i] : _df3_mean_str.iloc[:, i].str.cat(_df3_std_str.iloc[:, i], sep = \" $\\\\pm$ \") for i in range(_df3_mean_str.shape[1])})\n",
    "for i, j in enumerate(np.argmin(_df3_mean.values, 1)):\n",
    "    _df3_str.iloc[i, j] = \"\\\\textbf{\" + _df3_str.iloc[i, j] + \"}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f9177ee-ca85-4041-b657-3eb14b41f8d6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(_df1_str.iloc[:, ].to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa8ce449-b6b8-44fb-b87c-ab90aa268b9b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(_df2_str.iloc[:, [2, 0, 4, 3, 1]].to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29f4f377-c3a2-4cd8-ae41-f5388d576c13",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(_df3_str.iloc[:, [2, 0, 4, 3, 1]].to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "129c2852-7d0e-4c07-b1a4-939dfba94875",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "files = glob.glob(\"evals/df_fate_corr*.csv\")\n",
    "dims = [int(os.path.basename(f).split(\"_\")[8].split(\"dim\")[-1]) for f in files]\n",
    "df = pd.concat([pd.DataFrame(pd.read_csv(f, index_col = 0)).assign(dim = dim) for (f, dim) in zip(files, dims)])\n",
    "df = df.loc[df.what == 'pearson']\n",
    "_df_mean = df.groupby(['dim', 'what']).agg(['mean', ])\n",
    "_df_std = df.groupby(['dim', 'what']).agg(['std', ])\n",
    "_df_mean_str = _df_mean.applymap(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)\n",
    "_df_std_str = _df_std.applymap(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)\n",
    "_df_str = pd.DataFrame({_df_mean_str.columns[i] : _df_mean_str.iloc[:, i].str.cat(_df_std_str.iloc[:, i], sep = \" $\\\\pm$ \") for i in range(_df_mean_str.shape[1])})\n",
    "for i, j in enumerate(np.argmax(_df_mean.values, 1)):\n",
    "    _df_str.iloc[i, j] = \"\\\\textbf{\" + _df_str.iloc[i, j] + \"}\"\n",
    "print(_df_str.reset_index().drop(columns = 'what').set_index('dim').iloc[:, [0, 1, 3, 4, 2]].to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bca3ada-7c5d-4df5-99d1-58268df25547",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "_df_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96719c75-f4ca-459d-819e-73c57ffa8a7d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe6d751f-0fc3-4dd7-ae81-44daca017968",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1152aaf8-e352-4ced-a076-ffc7379aa6ab",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "upfi",
   "language": "python",
   "name": "upfi"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
