{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a361e05c-61a4-4bed-9ee3-f4d357960434",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Cap the thread pools of the numerical backends (OpenMP, OpenBLAS, MKL, Accelerate/vecLib, numexpr).\n",
    "# These variables are typically read when a backend is first loaded, so this cell must run before numpy/torch are imported.\n",
    "num_threads = \"4\"\n",
    "os.environ.update(dict.fromkeys(\n",
    "    [\"OMP_NUM_THREADS\", \"OPENBLAS_NUM_THREADS\", \"MKL_NUM_THREADS\",\n",
    "     \"VECLIB_MAXIMUM_THREADS\", \"NUMEXPR_NUM_THREADS\"],\n",
    "    num_threads))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35b830dc-c74b-4bc7-b471-706fc6af2053",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import importlib\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy as sp\n",
    "import sklearn as sk\n",
    "from sklearn import linear_model, preprocessing\n",
    "import torch\n",
    "import torchsde\n",
    "import dcor\n",
    "from torchdiffeq import odeint\n",
    "\n",
    "sys.path.append(\"../../src\")\n",
    "torch.set_default_dtype(torch.float32)\n",
    "\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "\n",
    "# Simulate repressilator data\n",
    "class Repressilator(torch.nn.Module):\n",
    "    \"\"\"Stochastic repressilator: gene 1 is repressed by gene 3, gene 2 by gene 1 and gene 3 by gene 2\n",
    "    (Hill-type repression with linear degradation), plus constant diagonal Ito noise of amplitude `sigma`.\"\"\"\n",
    "    noise_type = 'diagonal'\n",
    "    sde_type = 'ito'\n",
    "    beta = 10    # numerator of the repression term (its maximum value)\n",
    "    n = 3        # Hill exponent\n",
    "    k = 1        # repression threshold\n",
    "    gamma = 1    # linear degradation rate\n",
    "    sigma = 0.1  # diffusion amplitude\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "    def f(self, t, x):\n",
    "        \"\"\"Drift. `t` is unused (autonomous system); `x` is indexed as (batch, 3).\"\"\"\n",
    "        x1, x2, x3 = x[:, 0], x[:, 1], x[:, 2]\n",
    "        dx1 = self.beta / (1 + (x3 / self.k)**self.n) - self.gamma*x1\n",
    "        dx2 = self.beta / (1 + (x1 / self.k)**self.n) - self.gamma*x2\n",
    "        dx3 = self.beta / (1 + (x2 / self.k)**self.n) - self.gamma*x3\n",
    "        return torch.vstack([dx1, dx2, dx3]).T\n",
    "    def g(self, t, x):\n",
    "        \"\"\"Diffusion: constant `sigma` on every coordinate.\"\"\"\n",
    "        return torch.ones_like(x)*self.sigma\n",
    "\n",
    "N = 1000  # number of simulated cells\n",
    "T = 10    # number of observation time points\n",
    "x0 = torch.hstack([torch.randn(N, 2)*0.1+1, torch.randn(N, 1)*0.1+2])\n",
    "ts = torch.linspace(0, T-1, T)  # T unit-spaced observation times\n",
    "x = torchsde.sdeint(Repressilator(), x0, ts, dt = 0.05)  # indexed (time, cell, gene)\n",
    "\n",
    "# Per time point, draw 2*N_obs distinct cells: the first N_obs for training, the rest for validation\n",
    "N_obs = 100\n",
    "Xs_t = torch.cat([x[i, torch.randperm(N)[:2*N_obs], :].unsqueeze(0) for i in range(T)])\n",
    "Xs_t, Xs_t_val = Xs_t[:, :N_obs, :], Xs_t[:, N_obs:, :]\n",
    "\n",
    "# Centre only (with_std=False), so scaler_op.scale_ is None\n",
    "scaler_op = sk.preprocessing.StandardScaler(with_mean=True, with_std=False)\n",
    "scaler_op.fit(Xs_t.reshape(-1, Xs_t.shape[-1]))\n",
    "data = {'x' : torch.tensor(scaler_op.transform(Xs_t.reshape(-1, Xs_t.shape[-1])), dtype = torch.float32),\n",
    "        't_idx' : torch.tensor(np.repeat(np.arange(Xs_t.shape[0]), Xs_t.shape[1])), \n",
    "        'x_val' : torch.tensor(scaler_op.transform(Xs_t_val.reshape(-1, Xs_t_val.shape[-1])), dtype = torch.float32),\n",
    "        't_idx_val' : torch.tensor(np.repeat(np.arange(Xs_t_val.shape[0]), Xs_t_val.shape[1])), \n",
    "        'sigma' : Repressilator.sigma\n",
    "       }\n",
    "torch.save(data, \"data.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ed13cc1-d224-4d5e-a82b-dccf69e201e4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import scipy as sp\n",
    "# Ground-truth Jacobian of the drift at the symmetric fixed point x1 = x2 = x3 = x*,\n",
    "# where x* solves beta / (1 + (x/k)^n) = gamma * x\n",
    "_x = sp.optimize.root(lambda x: Repressilator.beta / (1 + (x/Repressilator.k)**Repressilator.n) - Repressilator.gamma*x, 1.).x[0]\n",
    "x_fixed = torch.full((1, 3), _x)\n",
    "# f ignores its time argument: pass None explicitly instead of IPython's last-output variable `_`\n",
    "J = torch.func.jacrev(lambda x: Repressilator().f(None, x))(x_fixed).detach().squeeze()  # J[i, j] = d f_i / d x_j\n",
    "# scaler_op maps x -> x' = D^{-1} (x - mean) with D = diag(scale_). The drift in scaled coordinates,\n",
    "# D^{-1} f(D x' + mean), therefore has Jacobian D^{-1} J D (D = I when scale_ is None).\n",
    "D = torch.diag(torch.tensor(scaler_op.scale_)).float() if scaler_op.scale_ is not None else torch.eye(3)\n",
    "J_scaled = torch.linalg.inv(D) @ J @ D\n",
    "\n",
    "torch.save({'x' : torch.tensor(scaler_op.transform(x_fixed)).flatten().float(), 'J' : J_scaled}, \"gt_ref.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27ae674f-3b42-4205-8956-820e70e379bf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Ground-truth trajectories from x0: first on a dense 100-point grid (for plotting), then on the\n",
    "# observation grid `ts`. The order matters because both draws consume the global torch RNG.\n",
    "ts_traj_dense = torch.linspace(0, T-1, 100)\n",
    "x_traj_dense, x_traj = (torchsde.sdeint(Repressilator(), x0, grid, dt = 0.05) for grid in (ts_traj_dense, ts))\n",
    "torch.save(dict(x_traj = x_traj, ts_traj = ts,\n",
    "                x_traj_dense = x_traj_dense, ts_traj_dense = ts_traj_dense), \"data_traj.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "047410d5-2eed-4b33-ab27-b5821ec62f25",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 3D scatter of all simulated cells coloured by time, with a few dense ground-truth trajectories on top\n",
    "n_traj_shown = 10\n",
    "fig = plt.figure(figsize = (4, 4))\n",
    "ax = fig.add_subplot(projection='3d')\n",
    "ax.view_init(elev=30, azim=45, roll=0)\n",
    "time_colour = ts.repeat_interleave(N)  # x is (time, cell, gene), so flattening is time-major\n",
    "ax.scatter(x[..., 0], x[..., 1], x[..., 2], c = time_colour, alpha = 0.1, s = 10, rasterized = True);\n",
    "for traj in x_traj_dense[:, :n_traj_shown, :].unbind(1):\n",
    "    ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], c = 'red', alpha = 0.5)\n",
    "plt.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/repressilator_3dplot_traj.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4a4d5b9-acac-4880-9a86-f0c9ffdc442e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Re-derive problem dimensions from the saved data dict\n",
    "d = data['x'].shape[-1]             # state dimension\n",
    "T = np.unique(data['t_idx']).size   # number of observed time points\n",
    "ts = torch.linspace(0, T - 1, T)    # unit-spaced observation times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95eb1ff5-3df0-4a24-9911-f8e299ec43dc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import fm\n",
    "from tqdm import tqdm\n",
    "import copy\n",
    "\n",
    "# Trained score (s) and flow (v) networks plus the learned reference parameters of one run/iteration\n",
    "suffix, it = \"leaveout_-1_seed_1\", 4\n",
    "device = torch.device('cpu')\n",
    "s_model = fm.MLP(d = d, hidden_sizes = [64, 64, 64], time_varying=True, activation = torch.nn.ReLU)\n",
    "v_model = fm.MLP(d = d, hidden_sizes = [64, 64, 64], time_varying=True, activation = torch.nn.ReLU)\n",
    "# map_location lets checkpoints written by a GPU run be loaded on a CPU-only machine\n",
    "s_model.load_state_dict(torch.load(f\"weights/otfm_score_iter_{it}_{suffix}.pt\", map_location = device))\n",
    "v_model.load_state_dict(torch.load(f\"weights/otfm_flow_iter_{it}_{suffix}.pt\", map_location = device))\n",
    "\n",
    "ref_params = torch.load(f\"weights/reference_iter_{it}_{suffix}.pt\", map_location = device)\n",
    "\n",
    "otfm = fm.LinearEntropicOTFM(data['x'], \n",
    "                      data['t_idx'], \n",
    "                      ts = ts,\n",
    "                      sigma = data[\"sigma\"],\n",
    "                      A = ref_params[\"A\"],\n",
    "                      mu = ref_params[\"mu\"],\n",
    "                      T = T,\n",
    "                      dim = d,\n",
    "                      device = device\n",
    "                  )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67ab35ee-4b90-4f5b-8e8e-8b2b2a9fc4d8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Ground-truth Jacobian in the same (centred/scaled) coordinates as the learned reference, next to the learned A.\n",
    "# Both are indexed [i, j] = d f_i / d x_j (assuming the learned reference drift is f(x) = A x, as in\n",
    "# OrnsteinUhlenbeck.f). Both are signed, so a diverging colormap centred on 0 is used.\n",
    "jac_cmap, jac_lim = \"RdBu_r\", 1.5\n",
    "fig, (ax_J, ax_A) = plt.subplots(1, 2, figsize = (5, 2.5))\n",
    "im = ax_J.imshow(J_scaled.squeeze(), cmap = jac_cmap, vmin = -jac_lim, vmax = jac_lim)\n",
    "fig.colorbar(im, ax = ax_J, fraction=0.046, pad=0.04)\n",
    "ax_J.set_title(\"$\\\\partial_j f_i$\")\n",
    "im = ax_A.imshow(ref_params['A'], cmap = jac_cmap, vmin = -jac_lim, vmax = jac_lim)\n",
    "fig.colorbar(im, ax = ax_A, fraction=0.046, pad=0.04)\n",
    "ax_A.set_title(\"Learned $A$\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"../../figures/repressilator_jac_vs_A.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9c937b6-1b32-4405-99aa-78a117066f64",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class OrnsteinUhlenbeck(torch.nn.Module):\n",
    "    \"\"\"Linear reference SDE  dX = A X dt + sigma dW  with diagonal Ito noise (torchsde interface).\"\"\"\n",
    "    noise_type = 'diagonal'\n",
    "    sde_type = 'ito'\n",
    "\n",
    "    def __init__(self, A, sigma):\n",
    "        super().__init__()\n",
    "        self.A, self.sigma = A, sigma\n",
    "\n",
    "    def f(self, t, x):\n",
    "        \"\"\"Drift A x applied to each row of the batch `x` (hence the transpose); `t` is unused.\"\"\"\n",
    "        return torch.matmul(x, self.A.T)\n",
    "\n",
    "    def g(self, t, x):\n",
    "        \"\"\"Diffusion: constant `sigma` on every coordinate.\"\"\"\n",
    "        return self.sigma * torch.ones_like(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad01ab70-e972-418e-9864-7d9e5e639552",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Try Gaussian approximation\n",
    "Xs = [data['x'][data['t_idx'] == i] for i in range(T)]\n",
    "means = [torch.mean(x, 0) for x in Xs]\n",
    "covs = [torch.cov(x.T) for x in Xs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dda83db0-1e68-4ec6-9c4a-2838aa925e9c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from scipy.stats import multivariate_normal\n",
    "import matplotlib.cm\n",
    "import matplotlib.colors\n",
    "\n",
    "def get_cmap(solid_color):\n",
    "    \"\"\"Linear colormap from white (at 0) to `solid_color` (an RGB(A) sequence, at 1).\"\"\"\n",
    "    cdict = {channel: [(0.0, 1.0, 1.0),                            # start at white\n",
    "                       (1.0, solid_color[c], solid_color[c])]      # end at the solid colour\n",
    "             for c, channel in enumerate(('red', 'green', 'blue'))}\n",
    "    return matplotlib.colors.LinearSegmentedColormap('WhiteToSolid', cdict)\n",
    "\n",
    "# Evaluation grid for the 2D density contours. Named *_grid so that it does not overwrite the\n",
    "# simulated trajectories `x` that the 3D trajectory plot reads.\n",
    "x_grid = np.linspace(-1.75, 2.5, 100)\n",
    "y_grid = np.linspace(-1.75, 2.5, 100)\n",
    "X, Y = np.meshgrid(x_grid, y_grid)\n",
    "pos = np.dstack((X, Y))\n",
    "\n",
    "def plot_bivariate(mean, cov, cm):\n",
    "    \"\"\"Draw 5 contour levels of the N(mean, cov) density on the (X, Y) grid, using colormap `cm`.\"\"\"\n",
    "    rv = multivariate_normal(mean, cov)\n",
    "    plt.contour(X, Y, rv.pdf(pos), levels=5, cmap=cm)\n",
    "\n",
    "# Kept for backward compatibility: it was an exact copy of plot_bivariate (it never drew anything 3D).\n",
    "plot_bivariate_3d = plot_bivariate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5df18243-912b-4058-8f11-945b30e71db7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "k = 0       # first of the two coordinates shown (x_k, x_{k+1})\n",
    "offset = 2  # bridge snapshot j to snapshot j+offset; snapshot j+offset//2 is held out for comparison\n",
    "t = torch.linspace(0, 1, 8)\n",
    "\n",
    "def plot_gaussian_sb_panel(j, A_scale, sigma_scale, title):\n",
    "    \"\"\"Plot the Gaussian SB interpolation between snapshots j and j+offset into the current axes.\n",
    "\n",
    "    The reference of `otfm` is rescaled on a deep copy (A *= A_scale, sigma *= sigma_scale), so `otfm`\n",
    "    itself is left untouched. Density contours of the bridge marginals at the times `t` are drawn over\n",
    "    the observed snapshots j, j+offset//2 and j+offset (coordinates k and k+1).\n",
    "    \"\"\"\n",
    "    _otfm = copy.deepcopy(otfm)\n",
    "    _otfm.A *= A_scale\n",
    "    _otfm.sigma *= sigma_scale\n",
    "    gsb = fm.GaussianOUSB(_otfm.bm, _otfm)\n",
    "    sb_means, sb_vars, S_t, d_sb_means = gsb.evaluate(t, means[j], means[j+offset], covs[j], covs[j+offset])\n",
    "    for i in range(len(t)):\n",
    "        cm = get_cmap(matplotlib.cm.brg(t[i].item()))\n",
    "        plot_bivariate(sb_means[i].flatten()[k:k+2], sb_vars[i].reshape(d, d)[k:k+2, :][:, k:k+2], cm = cm)\n",
    "    for t_idx, color, label in [(j, 'blue', \"$p_{i-1}$\"), (j+offset//2, 'red', \"$p_{i}$\"), (j+offset, 'green', \"$p_{i+1}$\")]:\n",
    "        idx = (data[\"t_idx\"] == t_idx)\n",
    "        plt.scatter(data['x'][idx, k], data['x'][idx, k+1], alpha = 0.5, marker = \"^\", color = color, label = label, rasterized = True)\n",
    "    plt.scatter(data['x'][:, k], data['x'][:, k+1], alpha = 0.1, marker = \"o\", color = 'grey', zorder = -100, rasterized = True)\n",
    "    plt.legend(); plt.xlabel(\"$x_1$\"); plt.ylabel(\"$x_2$\")\n",
    "    plt.title(title)\n",
    "\n",
    "for j in range(T - offset):\n",
    "    plt.figure(figsize = (7, 3.5))\n",
    "    plt.subplot(1, 2, 1)\n",
    "    # The bridge spans `offset` unit time steps; presumably this is why the drift is scaled by offset\n",
    "    # and sigma by sqrt(offset) (variance grows linearly with elapsed time).\n",
    "    plot_gaussian_sb_panel(j, offset, offset**0.5, \"mvOU reference\")\n",
    "    plt.subplot(1, 2, 2)\n",
    "    # Brownian reference: zero drift, with the same sqrt(offset) rescaling of sigma as the mvOU panel\n",
    "    plot_gaussian_sb_panel(j, 0, offset**0.5, \"Brownian reference\")\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"../../figures/repressilator_OU_interpolation_{j}.pdf\", dpi = 300)\n",
    "    plt.show()  # display and release each figure instead of keeping T-offset figures open"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dfcf258-62be-4d9d-90e5-d0ed34d03e7a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "offset = 2    # paths run from snapshot j to snapshot j+offset\n",
    "n_paths = 25  # sample paths drawn per panel\n",
    "\n",
    "# Learned SDE (drift v + sigma^2/2 * score) and the learned OU reference, both at the data's noise level\n",
    "sigma = data['sigma']\n",
    "sde = fm.SDE(lambda t, x: v_model(t, x) + sigma**2 / 2 * s_model(t, x), sigma)\n",
    "ou_sde = OrnsteinUhlenbeck(A = ref_params['A'], sigma = sigma)\n",
    "\n",
    "def plot_paths_panel(fig, position, j, _ts, simulate, title):\n",
    "    \"\"\"Add a 3D panel at subplot `position`: all training points, snapshots j, j+offset//2, j+offset, and the\n",
    "    first `n_paths` of the paths returned by `simulate(x_start, _ts)` started from snapshot j.\"\"\"\n",
    "    ax = fig.add_subplot(position, projection='3d')\n",
    "    ax.view_init(elev=30, azim=45, roll=0)\n",
    "    ax.scatter(data['x'][..., 0], data['x'][..., 1], data['x'][..., 2], c = 'k', alpha = 0.2, s = 2.5, marker = 'x', rasterized = True)\n",
    "    for t_idx, color, label in [(j, 'blue', \"$p_{i-1}$\"), (j+offset//2, 'red', \"$p_{i}$\"), (j+offset, 'green', \"$p_{i+1}$\")]:\n",
    "        idx = (data[\"t_idx\"] == t_idx)\n",
    "        ax.scatter(data['x'][idx, 0], data['x'][idx, 1], data['x'][idx, 2], alpha = 0.5, marker = \"^\", color = color, label = label, rasterized = True)\n",
    "    _xs = simulate(data['x'][data[\"t_idx\"] == j], _ts).detach()\n",
    "    for l in range(n_paths):\n",
    "        ax.plot(_xs[:, l, 0], _xs[:, l, 1], _xs[:, l, 2], alpha = 0.3, color = 'k')\n",
    "    ax.set_title(title)\n",
    "    ax.legend(); ax.set_xlabel(\"$x_1$\"); ax.set_ylabel(\"$x_2$\")\n",
    "\n",
    "for j in range(T - offset):\n",
    "    fig = plt.figure(figsize = (12, 3.5))\n",
    "    _ts = torch.linspace(ts[j], ts[j+offset], 50)\n",
    "    plot_paths_panel(fig, 131, j, _ts, lambda x_start, tt: odeint(v_model, x_start, tt), \"PF-ODE\")\n",
    "    plot_paths_panel(fig, 132, j, _ts, lambda x_start, tt: torchsde.sdeint(sde, x_start, tt, dt = 1e-2), \"SDE\")\n",
    "    plot_paths_panel(fig, 133, j, _ts, lambda x_start, tt: torchsde.sdeint(ou_sde, x_start, tt, dt = 1e-2), \"Reference\")\n",
    "    fig.suptitle(f\"i = {j+1}\")\n",
    "    plt.show()  # display and release each figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74ff4736-fc7a-47dd-af47-455a28d2d4e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check vector field reconstruction: ground-truth drift at the training points, mapped back to original coordinates.\n",
    "# f ignores its time argument, so pass None instead of IPython's last-output variable `_`.\n",
    "# scaler_op only centres the data (with_std=False), so velocities in original and centred coordinates coincide.\n",
    "v_gt = Repressilator().f(None, torch.tensor(scaler_op.inverse_transform(data['x'])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "475c87c4-0224-47fc-a7f9-de9e5a2f2513",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import utils\n",
    "\n",
    "def plot_vector_field(ax, v, title):\n",
    "    \"\"\"3D scatter of the training points coloured by time index, with unit-normalised arrows along `v`.\"\"\"\n",
    "    ax.view_init(elev=30, azim=45, roll=0)\n",
    "    ax.set_title(title)\n",
    "    ax.scatter(data['x'][..., 0], data['x'][..., 1], data['x'][..., 2], c = data['t_idx'], rasterized = True)\n",
    "    ax.quiver(data['x'][..., 0], data['x'][..., 1], data['x'][..., 2], v[..., 0], v[..., 1], v[..., 2], normalize = True, length = 0.5, color = 'k', alpha = 0.25, arrow_length_ratio=0.5, rasterized = True)\n",
    "    ax.axis('off')\n",
    "\n",
    "def learned_drift(it):\n",
    "    \"\"\"Load iteration `it` of run `suffix` into s_model / v_model and return the SDE drift v + sigma^2/2 * s\n",
    "    evaluated at the training points. Rows are stacked per time index, matching the row order of data['x'].\"\"\"\n",
    "    s_model.load_state_dict(torch.load(f\"weights/otfm_score_iter_{it}_{suffix}.pt\", map_location = 'cpu'))\n",
    "    v_model.load_state_dict(torch.load(f\"weights/otfm_flow_iter_{it}_{suffix}.pt\", map_location = 'cpu'))\n",
    "    sde = fm.SDE(lambda t, x: v_model(t, x) + data['sigma']**2 / 2 * s_model(t, x), data['sigma'])\n",
    "    with torch.no_grad():\n",
    "        return torch.vstack([sde.f(torch.scalar_tensor(t), data['x'][data['t_idx'] == i, ...]) for i, t in enumerate(ts)])\n",
    "\n",
    "fig = plt.figure(figsize = (12, 4))\n",
    "plot_vector_field(fig.add_subplot(131, projection='3d'), v_gt, \"$v_{true}$\")\n",
    "\n",
    "ax = fig.add_subplot(133, projection='3d')\n",
    "v = learned_drift(0)\n",
    "plot_vector_field(ax, v, \"$v_{SB}$   (BM reference)\")\n",
    "print(utils.cos_dist(v_gt.float(), v).mean())\n",
    "\n",
    "ax = fig.add_subplot(132, projection='3d')\n",
    "v = learned_drift(4)  # loaded last, so s_model / v_model end up holding the iteration-4 weights\n",
    "plot_vector_field(ax, v, \"$v_{SB}$   (mvOU reference)\")\n",
    "print(utils.cos_dist(v_gt.float(), v).mean())\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"../../figures/repressilator_3dplot_vfs.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d952321-4b44-4916-a8ad-a87044de93f4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Cosine distance between the ground-truth drift and the drift A x of the learned OU reference\n",
    "with torch.no_grad():\n",
    "    per_time = [ou_sde.f(torch.scalar_tensor(t_val), data['x'][data['t_idx'] == t_pos]) for t_pos, t_val in enumerate(ts)]\n",
    "    v = torch.cat(per_time, dim = 0)\n",
    "print(utils.cos_dist(v_gt.float(), v).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "848f320f-08da-49bb-93bc-7645d8e0a536",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import seaborn as sb\n",
    "import matplotlib.pyplot as plt\n",
    "# Dedicated name so the `suffix` of the loaded model weights is not overwritten\n",
    "eval_suffix = \"leaveout_-1_seed_5\"\n",
    "df = pd.read_csv(f\"evals/marginal_validation_{eval_suffix}.csv\", index_col = 0).iloc[:, 1:]\n",
    "fig, (ax_energy, ax_emd) = plt.subplots(1, 2, figsize = (5, 3))\n",
    "sb.boxplot(df[df.index == \"energy\"], ax = ax_energy); ax_energy.set(xlabel = \"Iterations\", ylabel = \"Energy distance\")\n",
    "sb.boxplot(df[df.index == \"emd\"], ax = ax_emd); ax_emd.set(xlabel = \"Iterations\", ylabel = \"EMD\")\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f5a0225-afca-4dd9-b0f8-c579992f10d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "def read_eval_csvs(pattern):\n",
    "    \"\"\"Concatenate all evaluation CSVs matching `pattern`, naming the two unnamed leading columns 'metric' and 't'.\"\"\"\n",
    "    paths = sorted(glob.glob(pattern))\n",
    "    if not paths:\n",
    "        raise FileNotFoundError(f\"no evaluation files match {pattern!r}\")\n",
    "    return pd.concat([pd.read_csv(p).rename(columns = {'Unnamed: 0' : 'metric', 'Unnamed: 1' : 't'}) for p in paths])\n",
    "\n",
    "# For each method and each left-out time point i, keep only the errors evaluated at t == i\n",
    "_dfs = []\n",
    "for s in [\"_sbirr_OU\", \"_sbirr_general\"]:\n",
    "    per_leaveout = []\n",
    "    for i in range(1, 9):\n",
    "        df = read_eval_csvs(f\"evals/marginal_adjacent_validation{s}_leaveout_{i}*.csv\")\n",
    "        df = df.melt(id_vars = ['metric', 't'], value_vars = df.columns[2:])\n",
    "        per_leaveout.append(df[df.t == i])\n",
    "    _dfs.append(pd.concat(per_leaveout).rename(columns = {'variable' : 'iteration', 'value' : 'error'}))\n",
    "# Label each method's final 'sbirr' column; from the second method only those rows are kept\n",
    "_dfs[0] = _dfs[0].replace({'iteration' : {'sbirr' : 'sbirr_OU'}})\n",
    "_dfs[1] = _dfs[1].replace({'iteration' : {'sbirr' : 'sbirr_general'}})\n",
    "dfs = pd.concat([_dfs[0], _dfs[1].loc[_dfs[1].iteration.str.contains('sbirr'), :]], axis = 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7845d84e-77e3-4c4a-9e65-a4b9453ec27d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `iteration` holds strings (CSV column labels), so the category order is given as strings too\n",
    "g = sb.FacetGrid(dfs, col=\"t\",  row=\"metric\", sharey = False)\n",
    "g.map(sb.boxplot, \"iteration\", \"error\", order = ['0', '1', '2', '3', '4'], color = \"lightgrey\")\n",
    "\n",
    "# Dashed red line: mean error of the gtref evaluation runs for each (metric, t) facet.\n",
    "# The method suffix is spelled out instead of relying on a loop variable leaking from another cell.\n",
    "gt_suffix = \"_sbirr_general\"\n",
    "df_gt = pd.concat([pd.read_csv(x).rename( columns={'Unnamed: 0':'metric', 'Unnamed: 1' : 't'}) for x in glob.glob(f\"evals/marginal_adjacent_validation{gt_suffix}_leaveout_-1_seed_*_gtref.csv\")])\n",
    "df_gt = df_gt.loc[df_gt.t.isin(range(0, 9)), :]\n",
    "df_gt = df_gt.melt(id_vars = ['metric', 't'], value_vars = df_gt.columns[2:])\n",
    "df_gt = df_gt.rename(columns = {'variable' : 'iteration', 'value' : 'error'})\n",
    "for row, row_name in enumerate(g.row_names):\n",
    "    for col, col_name in enumerate(g.col_names):\n",
    "        gt_mean = df_gt.loc[(df_gt.metric == row_name) & (df_gt.t == col_name)].error.mean()\n",
    "        g.axes[row, col].axhline(gt_mean, color='red', ls='--')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e983cb79-da02-41d0-8239-25d1a0dfcd98",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# Mean +/- std of the error for each (metric, t) row and iteration column, as LaTeX strings;\n",
    "# the lowest mean in each row is set in bold.\n",
    "_df1_mean = dfs.groupby(['metric' , 't', 'iteration'])[['error']].mean().unstack()\n",
    "_df1_std = dfs.groupby(['metric' ,'t', 'iteration'])[['error']].std().unstack()\n",
    "# Column-wise Series.map instead of DataFrame.applymap, which is deprecated since pandas 2.1\n",
    "_df1_mean_str = _df1_mean.apply(lambda col: col.map(\"{:.2f}\".format))\n",
    "_df1_std_str = _df1_std.apply(lambda col: col.map(\"{:.2f}\".format))\n",
    "_df1_str = _df1_mean_str + \" $\\\\pm$ \" + _df1_std_str\n",
    "for i, j in enumerate(np.argmin(_df1_mean.values, 1)):\n",
    "    _df1_str.iloc[i, j] = \"\\\\textbf{\" + _df1_str.iloc[i, j] + \"}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03c95d96-b5e6-40dd-95c3-a47cf4579da0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Display the per-(metric, t) mean +/- std table; the best (lowest-mean) iteration of each row is in bold\n",
    "_df1_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f8b7c89-f660-4d23-a83f-41afca52da4f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# escape=False keeps the \\textbf{...} and $\\pm$ markup intact (pandas 1.x escapes LaTeX special characters by default)\n",
    "print(_df1_str.to_latex(escape=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12e82472-6aec-468e-80d3-0942dae21738",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Error averaged over the left-out time points, per metric and iteration/method.\n",
    "# `iteration` holds strings (CSV column labels / method names), so the order is given as strings.\n",
    "_df = dfs.groupby(['metric', 'iteration'])['error'].mean().reset_index()\n",
    "g = sb.FacetGrid(_df, col = \"metric\", sharey = False, height = 2, aspect = 1.25)\n",
    "g.map(sb.barplot, \"iteration\", \"error\", order = ['0', '1', '2', '3', '4', 'sbirr_OU', 'sbirr_general'], color = \"lightgrey\")\n",
    "plt.suptitle(\"Averaged leave-one-out error\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/repressilator_leave_one_out_avg.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26c10327-3d8f-4980-a002-57d3ebc3ca8f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Mean +/- std of the error per metric (pooled over left-out time points) and iteration, as LaTeX strings;\n",
    "# the lowest mean in each row is set in bold.\n",
    "_df1_mean = dfs.groupby(['metric' ,'iteration'])[['error']].mean().unstack()\n",
    "_df1_std = dfs.groupby(['metric' ,'iteration'])[['error']].std().unstack()\n",
    "# Column-wise Series.map instead of DataFrame.applymap, which is deprecated since pandas 2.1\n",
    "_df1_mean_str = _df1_mean.apply(lambda col: col.map(\"{:.2f}\".format))\n",
    "_df1_std_str = _df1_std.apply(lambda col: col.map(\"{:.2f}\".format))\n",
    "_df1_str = _df1_mean_str + \" $\\\\pm$ \" + _df1_std_str\n",
    "for i, j in enumerate(np.argmin(_df1_mean.values, 1)):\n",
    "    _df1_str.iloc[i, j] = \"\\\\textbf{\" + _df1_str.iloc[i, j] + \"}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82b9ae2b-2d81-4076-8839-57cb15959741",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# escape=False keeps the \\textbf{...} and $\\pm$ markup intact (pandas 1.x escapes LaTeX special characters by default)\n",
    "print(_df1_str.to_latex(escape=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c000897-c06a-4967-b120-78463d1d7b6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the pooled per-metric mean +/- std table; the best (lowest-mean) iteration of each row is in bold\n",
    "_df1_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b35f7ac-5f04-48af-9153-42577e87251d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2620c716-140e-4e76-a918-b533b2536e75",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lfm",
   "language": "python",
   "name": "lfm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
