{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24200ca3-803f-44a5-986f-c1651466b817",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import importlib\n",
    "import sys\n",
    "\n",
    "# Third-party packages\n",
    "import torch\n",
    "from torch import optim\n",
    "import torchdiffeq\n",
    "from torchdiffeq import odeint\n",
    "import geomloss\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import sklearn as sk\n",
    "import sklearn.decomposition\n",
    "import sklearn.preprocessing\n",
    "import scanpy as sc\n",
    "import anndata as ad\n",
    "\n",
    "# Project modules live in ../../src; reload so edits to models.py take effect without a kernel restart.\n",
    "sys.path.append(\"../../src\")\n",
    "import models\n",
    "importlib.reload(models)\n",
    "\n",
    "# Models in this notebook are placed on the first CUDA device.\n",
    "device = torch.device('cuda:0')\n",
    "\n",
    "# Seed the global NumPy and PyTorch RNGs for reproducibility.\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0daf34b9-a1e6-48f5-8d8f-53f83ebbaeb2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# NOTE(review): weights_only = False makes torch.load fully unpickle the file, which can execute arbitrary code;\n",
    "# presumably needed because the dict holds non-tensor objects. Only load this file from a trusted source.\n",
    "data = torch.load(\"data_pca.pkl\", weights_only = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e0cc415-55c9-46da-bc74-bda0a59238f1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Size of the expression matrix (assumed cells x PCA dims -- TODO confirm).\n",
    "data['x'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bae5ba87-bf1d-4a90-8646-b498ec3913b2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# SPRING embedding of all cells, coloured by time index.\n",
    "fig, ax = plt.subplots(figsize = (2.5, 2.5))\n",
    "spring = data['x_spring']\n",
    "ax.scatter(spring[:, 0], spring[:, 1], c = data['t_idx'], s = 0.1, alpha = 0.25, cmap = 'viridis')\n",
    "ax.axis('off')\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "832b9b8a-ebe5-4b1c-9a3f-b54e6eea1d11",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Colour per cell-type annotation. Erythroid and Undifferentiated share grey, so only the\n",
    "# Monocyte and Neutrophil cells get distinct colours in the figures.\n",
    "color_dict = dict(\n",
    "    Erythroid = 'grey',\n",
    "    Monocyte = '#2320ba',\n",
    "    Neutrophil = '#a81616',\n",
    "    Undifferentiated = 'grey',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06918582-1aed-4ad1-b04c-955e2fbcdcf1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# SPRING embedding split by time index (one panel each), coloured by cell-type annotation.\n",
    "fig, axes = plt.subplots(1, 3, figsize = (7.5, 2.5))\n",
    "for t_i, ax in enumerate(axes):\n",
    "    in_t = data['t_idx'] == t_i\n",
    "    colors = [color_dict[ct] for ct in data['celltype'][in_t]]\n",
    "    ax.scatter(data['x_spring'][in_t, 0], data['x_spring'][in_t, 1], c = colors, s = 0.1, alpha = 0.25, rasterized = True)\n",
    "    ax.set_xlim(-1000, 3500)\n",
    "    ax.set_ylim(-2500, 1500)\n",
    "    ax.axis('off')\n",
    "fig.savefig(\"../../figures/LARRY_spring_vs_time.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8131a84-efcd-4565-8953-fd4963749bca",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import importlib\n",
    "import models\n",
    "device = torch.device('cuda:0')\n",
    "importlib.reload(models)\n",
    "\n",
    "# Experiment configuration: trained seed to load, number of time points, D, PCA dimension, MLP widths.\n",
    "seed = 2\n",
    "T = 3\n",
    "D = 0.25\n",
    "dim = 10\n",
    "hidden_sizes = [128, 128, 128]\n",
    "ts = torch.linspace(0, 1, T)\n",
    "# Cells observed at each time index, moved to `device`.\n",
    "X = [data['x'][data['t_idx'] == i].to(device) for i in range(T)]\n",
    "\n",
    "def _load_weights(model, name):\n",
    "    \"\"\"Load weights/params_{name}_default_additive_pcadim_{dim}_seed_{seed}_final.pt into `model` and return it.\n",
    "\n",
    "    map_location puts the tensors on `device` regardless of where the checkpoint was saved, and\n",
    "    weights_only = True refuses arbitrary pickled objects (plain state dicts load fine).\n",
    "    \"\"\"\n",
    "    path = f\"weights/params_{name}_default_additive_pcadim_{dim}_seed_{seed}_final.pt\"\n",
    "    model.load_state_dict(torch.load(path, map_location = device, weights_only = True))\n",
    "    return model\n",
    "\n",
    "# Score network and the noise levels used by the Langevin samplers.\n",
    "s = _load_weights(models.NCScoreFunc(d = dim, hidden_sizes = hidden_sizes, activation = torch.nn.ReLU, time_dependent = True).to(device), \"NCScoreFunc\")\n",
    "sigmas = torch.linspace(0, -2, 5, device = device).exp()\n",
    "\n",
    "# Dynamics models fitted by each method.\n",
    "v_upfi = _load_weights(models.ODEFlowGrowth(d = dim, kwargs_v = {'hidden_sizes' : hidden_sizes, 'time_dependent' : False},\n",
    "                                            kwargs_g = {'hidden_sizes' : hidden_sizes, 'time_dependent' : False}).to(device), \"UPFI_ODEFlowGrowth\")\n",
    "v_pfi = _load_weights(models.VectorField(d = dim, hidden_sizes = hidden_sizes, time_dependent = True).to(device), \"PFI_VectorField\")\n",
    "v_ode = _load_weights(models.ODEFlowGrowthCoupled(d = dim, hidden_sizes = hidden_sizes, time_dependent = True).to(device), \"ODE_ODEFlowGrowthCoupled\")\n",
    "v_tigon = _load_weights(models.ODEFlowGrowth(d = dim, kwargs_v = {'hidden_sizes' : hidden_sizes, 'time_dependent' : True},\n",
    "                                             kwargs_g = {'hidden_sizes' : hidden_sizes, 'time_dependent' : True}).to(device), \"TIGON_ODEFlowGrowth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0438224-bcbe-483f-be09-fa9393df9a1b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import seaborn as sb\n",
    "# Cell types shown in the bar plot (Erythroid cells are left out).\n",
    "_celltypes = ['Undifferentiated', 'Monocyte', 'Neutrophil']\n",
    "plt.figure(figsize = (2.5, 2.5))\n",
    "# Count the shown cell types per time index. Selecting them by name (rather than dropping the last, i.e. least\n",
    "# frequent, entry of value_counts) keeps the same three columns at every time point.\n",
    "_counts = np.array([data['celltype'][data['t_idx'] == i].value_counts().reindex(_celltypes, fill_value = 0).values for i in range(T)], dtype = float)\n",
    "_df = pd.DataFrame(_counts / _counts.sum(1, keepdims = True), columns = _celltypes).reset_index()\n",
    "_df = _df.rename(columns = {'index' : 't'})\n",
    "_df = _df.melt(id_vars = ['t'], value_vars = _celltypes, var_name = \"Cell type annotation\")\n",
    "g = sb.barplot(_df, y = \"value\", hue = \"Cell type annotation\", x = \"t\", palette = color_dict)\n",
    "# .cpu(): a later cell rebinds `ts` to a CUDA tensor, and .numpy() fails on CUDA tensors.\n",
    "plt.ylabel(\"Mass\"); plt.xlabel(\"$t$\"); plt.xticks(range(T), ts.cpu().numpy())\n",
    "sb.move_legend(g, \"upper right\", bbox_to_anchor=(2., 1))\n",
    "plt.savefig(\"../../figures/larry_relative_masses.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9ede186-351b-4048-bf6c-5430a967a67d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "s.to(device)\n",
    "# One Langevin sampler per time in `ts`; the default argument `_t` binds that time, so each lambda\n",
    "# evaluates the score network `s` at its own t.\n",
    "samplers = [models.LangevinSampler(lambda x, sigma, _t = torch.scalar_tensor(_s).to(device): s(_t, x, sigma), \n",
    "                      torch.randn(1_000, dim).to(device), \n",
    "                      sigmas = sigmas, dt = 3e-3, n_iter = 5000) for _s in ts]\n",
    "# Loop variable named `sampler` rather than `s` to avoid confusion with the score network.\n",
    "x_sample = torch.vstack([sampler.sample().cpu()[None, ...] for sampler in samplers])\n",
    "X_all = torch.vstack(X).cpu()\n",
    "s.cpu()\n",
    "x_min, x_max = (X_all[:, 0]).min()-1, (X_all[:, 0]).max()+1\n",
    "y_min, y_max = (X_all[:, 1]).min()-1, (X_all[:, 1]).max()+1\n",
    "\n",
    "plt.figure(figsize = (10, 2.5))\n",
    "for i in range(len(X)):\n",
    "    # NOTE(review): the grid has 5 columns but only len(X) panels are drawn.\n",
    "    plt.subplot(1,5, i+1)\n",
    "    plt.scatter(X[i].cpu()[:, 0], X[i].cpu()[:, 1], s = 1, c = 'r', alpha = 0.01, rasterized = True, label = \"Data\")\n",
    "    plt.scatter(x_sample[i, :, 0].cpu(), x_sample[i, :, 1], s = 5, c = 'b', alpha = 0.1, rasterized = True, label = \"Sampled\")\n",
    "    plt.xlim(x_min, x_max); plt.ylim(y_min, y_max)\n",
    "    plt.axis(\"off\")\n",
    "    plt.title(f\"t = {ts[i]:.2f}\")\n",
    "    if i == 0:\n",
    "        leg = plt.legend()\n",
    "        # Make the legend markers opaque. Legend.legendHandles was renamed legend_handles in Matplotlib 3.7\n",
    "        # and the old name was removed in 3.9, so use whichever exists.\n",
    "        handles = leg.legend_handles if hasattr(leg, \"legend_handles\") else leg.legendHandles\n",
    "        for lh in handles:\n",
    "            lh.set_alpha(1)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/larry_score_validation.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16365be8-2e6b-47a5-8ddd-7a37916e1e9e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import utils\n",
    "import seaborn as sb\n",
    "\n",
    "def _eval_by_time(f):\n",
    "    \"\"\"Evaluate f(ts[i], x) on the cells with data['t_idx'] == i, for i in range(T).\n",
    "\n",
    "    The outputs are written back into the row order of data['x'], so they stay aligned with\n",
    "    data['v'], data['x_spring'] and data['celltype'] even if cells are not sorted by time.\n",
    "    Rows whose t_idx is outside range(T) are left as NaN.\n",
    "    \"\"\"\n",
    "    out = None\n",
    "    for i in range(T):\n",
    "        mask = data['t_idx'] == i\n",
    "        y = f(ts[i], data['x'][mask, :].to(device)).detach().cpu()\n",
    "        if out is None:\n",
    "            out = torch.full((data['x'].shape[0], ) + tuple(y.shape[1:]), float('nan'), dtype = y.dtype)\n",
    "        out[mask] = y\n",
    "    return out\n",
    "\n",
    "# v_upfi's velocity net is built with time_dependent = False but still takes a time argument: pass ts[0]\n",
    "# explicitly instead of IPython's `_` (the previous cell output), which depends on execution order.\n",
    "_v_upfi = v_upfi.v_net(ts[0], data['x'].to(device)).detach().cpu()\n",
    "_v_pfi = _eval_by_time(v_pfi)\n",
    "_v_tigon = _eval_by_time(v_tigon.v_net)\n",
    "_v_ode = _eval_by_time(v_ode.dF)\n",
    "_g_upfi = _eval_by_time(v_upfi.g_net).flatten()\n",
    "_g_ode = _eval_by_time(v_ode.F_net).flatten()\n",
    "_g_tigon = _eval_by_time(v_tigon.g_net).flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a21b2140-3755-4878-b4a3-81d488807f7e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# utils.cos_dist between each method's predicted velocity and the RNA velocity data['v'], over all cells.\n",
    "_cos = {name: utils.cos_dist(data['v'], v_pred)\n",
    "        for name, v_pred in [(\"UPFI\", _v_upfi), (\"PFI\", _v_pfi), (\"ODE\", _v_ode), (\"TIGON\", _v_tigon)]}\n",
    "sb.barplot(_cos)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a792570-26d4-4225-ace9-bad5cea40ad5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize = (7.5, 2.5))\n",
    "k = 0\n",
    "# Growth-net outputs of each method (per cell), shown on the SPRING embedding.\n",
    "for panel, (title, rates) in enumerate([(\"UPFI\", _g_upfi), (\"ODE\", _g_ode), (\"TIGON\", _g_tigon)], start = 1):\n",
    "    plt.subplot(1, 3, panel)\n",
    "    points = plt.scatter(data['x_spring'][:, k], data['x_spring'][:, k+1], alpha = 0.5, c = rates, cmap = \"RdBu_r\", vmin = -5, vmax = 5, s = 2, rasterized = True)\n",
    "    # Opaque colorbar despite the translucent points; Colorbar.draw_all() was removed in Matplotlib 3.8,\n",
    "    # so the alpha of the already-drawn colour strip is set directly.\n",
    "    cb = plt.colorbar(points); cb.set_alpha(1); cb.solids.set_alpha(1)\n",
    "    plt.title(title)\n",
    "    plt.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/larry_growth_rates.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fe3e116-9a98-4cbf-bafa-2e82c415f7f7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Per-cell cos_dist to the RNA velocity, split by annotated cell type and time index.\n",
    "_methods = {\"UPFI\": _v_upfi, \"PFI\": _v_pfi, \"ODE\": _v_ode, \"TIGON\": _v_tigon}\n",
    "_dfs = []\n",
    "for c in [\"Undifferentiated\", \"Monocyte\", \"Neutrophil\"]:\n",
    "    for t in range(T):\n",
    "        idx = (data[\"celltype\"] == c) & (data['t_idx'] == t)\n",
    "        _cols = {name: utils.cos_dist(data['v'][idx], v_pred[idx]) for name, v_pred in _methods.items()}\n",
    "        _cols.update(celltype = c, t = t)\n",
    "        _dfs.append(pd.DataFrame(_cols))\n",
    "_df = pd.concat(_dfs).melt(id_vars = ['t', 'celltype'], value_vars = list(_methods))\n",
    "g = sb.FacetGrid(_df, col = \"t\")\n",
    "g.map(sb.barplot, 'celltype', 'value', 'variable', palette = \"tab10\")\n",
    "g.add_legend(title = \"variable\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07751dcb-ee4a-4f21-a1e2-f8dd5e29db35",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Mean cos_dist per method over all cells not annotated 'Undifferentiated'.\n",
    "idx = data['celltype'] != 'Undifferentiated'\n",
    "_dists = {}\n",
    "for name, v_pred in [('upfi', _v_upfi), ('pfi', _v_pfi), ('ode', _v_ode), ('tigon', _v_tigon)]:\n",
    "    _dists[name] = utils.cos_dist(data['v'][idx], v_pred[idx])\n",
    "pd.DataFrame(_dists).mean(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0e35575-d275-4e2d-ab23-23422eff3c60",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Aggregate the per-(t, celltype) mean cosine distances over the five seed runs and print a LaTeX table of\n",
    "# mean +/- std per method, with the lowest mean in each row in bold.\n",
    "_method_order = [\"UPFI\", \"PFI\", \"ODE\", \"TIGON\"]\n",
    "_fmt2 = lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x\n",
    "_runs = [pd.read_csv(f\"df_rnavelo_cos_mean_seed_{k+1}.csv\", header = 0, index_col = 0).assign(seed=k) for k in range(5)]\n",
    "_df = pd.concat(_runs).reset_index().pivot(index = ['t', 'celltype', 'seed'], columns = 'variable').reset_index()\n",
    "_by_group = _df.groupby(['t', 'celltype'])\n",
    "_df_mean = _by_group.mean().reindex(columns=_method_order, level = \"variable\")\n",
    "_df_std = _by_group.std().reindex(columns=_method_order, level = \"variable\")\n",
    "_df_mean_str = _df_mean.applymap(_fmt2)\n",
    "_df_std_str = _df_std.applymap(_fmt2)\n",
    "_df_str = pd.DataFrame({col : _df_mean_str.iloc[:, i].str.cat(_df_std_str.iloc[:, i], sep = \" $\\\\pm$ \") for i, col in enumerate(_df_mean_str.columns)})\n",
    "for row, best in enumerate(np.argmin(_df_mean.values, 1)):\n",
    "    _df_str.iloc[row, best] = \"\\\\textbf{\" + _df_str.iloc[row, best] + \"}\"\n",
    "_df_str = _df_str.reset_index()\n",
    "print(_df_str.to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50410f9b-7492-4991-badb-96c951f2900f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize = (15, 3))\n",
    "# Per-cell cos_dist between each method's velocity and the RNA velocity, on the SPRING embedding.\n",
    "for panel, (title, v_pred) in enumerate([(\"UPFI\", _v_upfi), (\"PFI\", _v_pfi), (\"ODE\", _v_ode), (\"TIGON\", _v_tigon)], start = 1):\n",
    "    plt.subplot(1, 4, panel)\n",
    "    z = utils.cos_dist(data['v'], v_pred)\n",
    "    points = plt.scatter(data['x_spring'][:, 0], data['x_spring'][:, 1], alpha = 0.3, c = z, cmap = \"RdBu_r\", vmin = 0, vmax = 0.5, s = 2, rasterized = True)\n",
    "    plt.axis('off')\n",
    "    plt.title(title)\n",
    "    # Opaque colorbar despite the translucent points; Colorbar.draw_all() was removed in Matplotlib 3.8,\n",
    "    # so the alpha of the already-drawn colour strip is set directly.\n",
    "    cb = plt.colorbar(points); cb.set_alpha(1); cb.solids.set_alpha(1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67a35c81-026e-4908-b616-d433f7ccdd73",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import plotting\n",
    "importlib.reload(plotting)\n",
    "k = 0\n",
    "# Cell-type colours (shared by every panel) and the velocity field streamed in each panel, in PCA dims k, k+1.\n",
    "_colors = [color_dict[x] for x in data['celltype']]\n",
    "_fields = [(\"RNA velocity\", data['v']), (\"UPFI\", _v_upfi), (\"PFI\", _v_pfi), (\"ODE\", _v_ode), (\"TIGON\", _v_tigon)]\n",
    "plt.figure(figsize = (12.5, 2.5))\n",
    "for panel, (title, field) in enumerate(_fields, start = 1):\n",
    "    plt.subplot(1, 5, panel)\n",
    "    plt.scatter(data['x'][:, k], data['x'][:, k+1], alpha = 0.1, color = _colors, s = 0.1, rasterized = True)\n",
    "    plotting.plot_stream_vectorfield(data['x'][:, k:k+2], field[:, k:k+2], density = 1.5, color = 'k', lw0 = 0.5)\n",
    "    plt.title(title); plt.axis('off')\n",
    "plt.savefig(\"../../figures/LARRY_PCA_force.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8807f73-b916-4bae-9a12-f0b6c155f079",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Deterministic check: transport X_batch[0] with each method's drift corrected by the learned score.\n",
    "import torchdiffeq\n",
    "import utils\n",
    "# Fixed-step Euler integration with step t_final / (5 T).\n",
    "odeint_options = {'method' : 'euler', 'options' : {'step_size' : data['t_final'] / (5*T)}}\n",
    "\n",
    "# NOTE: rebinds the global `ts` to len(X) observation times on [0, t_final] on `device`; cells run\n",
    "# after this one see these times instead of torch.linspace(0, 1, T) on the CPU.\n",
    "ts = torch.linspace(0, data['t_final'], len(X), device = device)\n",
    "# Number of observed cells at each time point relative to the first.\n",
    "m_ratios = torch.tensor([x.shape[0] / X[0].shape[0] for x in X]).float()\n",
    "for _net in (v_upfi, v_pfi, s):\n",
    "    _net.to(device)\n",
    "\n",
    "D = 0.25\n",
    "\n",
    "def F_ode_upfi(t, x):\n",
    "    \"\"\"UPFI drift minus (D/2) x score; state column 0 gets no score correction, columns 1: do.\"\"\"\n",
    "    drift = v_upfi(t, x)\n",
    "    correction = torch.hstack([torch.zeros_like(x[:, :1]), s(t, x[:, 1:], sigmas[-1]), ])\n",
    "    return drift - (D/2)*correction\n",
    "\n",
    "def F_ode_pfi(t, x):\n",
    "    \"\"\"PFI drift minus (D/2) x score, with the score evaluated at the smallest noise level sigmas[-1].\"\"\"\n",
    "    return v_pfi(t, x) - (D/2)*s(t, x, sigmas[-1])\n",
    "\n",
    "X_batch = utils.sample_batch_upfi(X, m_ratios.to(device), 5000, replacement = True).to(device)\n",
    "X_t_upfi = torchdiffeq.odeint(F_ode_upfi, X_batch[0], ts, **odeint_options).detach().cpu()\n",
    "X_t_pfi = torchdiffeq.odeint(F_ode_pfi, X_batch[0][:, 1:], ts, **odeint_options).detach().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfb43186-1512-451c-90e2-ab24bcaf2add",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Observed cells at each later time index (grey) vs ODE-transported UPFI particles; state columns 1 and 2\n",
    "# correspond to data['x'] columns 0 and 1 (state column 0 is treated as log-mass elsewhere).\n",
    "for i in range(1, T):\n",
    "    plt.subplot(1, 2, i)\n",
    "    at_t = data['t_idx'] == i\n",
    "    plt.scatter(data['x'][at_t, 0], data['x'][at_t, 1], s = 1, c = 'grey', alpha = 0.1)\n",
    "    plt.scatter(X_t_upfi[i, :, 1], X_t_upfi[i, :, 2], s = 1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f76347c0-4a29-4f1b-b492-254eeaef22e1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import evals\n",
    "# Sinkhorn divergence, at each observation time, between the ODE-transported UPFI particles and the sampled\n",
    "# data batch; exp of state column 0 is passed as the particle weight x_w.\n",
    "[evals.sinkhorn_divergence(X_t_upfi[i, :, 1:], X_batch[i][:, 1:].cpu(), x_w = X_t_upfi[i, :, 0].exp()) for i in range(T)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b7cdea7-67b1-4891-8008-20dfb16bdc79",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Same comparison for the ODE-transported PFI particles (unweighted; the PFI state excludes column 0 of X_batch).\n",
    "[evals.sinkhorn_divergence(X_t_pfi[i, ...], X_batch[i][:, 1:].cpu()) for i in range(T)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b633d7af-795e-4433-8322-af6740013210",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torchsde\n",
    "import utils\n",
    "# Output grid for the SDE simulations, spanning the same interval as the observation times `ts`, so the\n",
    "# trajectories reach ts[-1] even when t_final differs from 1 (for T = 3, ts[1] is grid point 12).\n",
    "_ts = torch.linspace(0, float(ts[-1]), 25)\n",
    "x0_mass = utils.sample_batch_upfi(X, m_ratios.to(device), 1024)[0]\n",
    "x0 = x0_mass[..., 1:]\n",
    "sde_pfi = models.SDE(lambda t, x: v_pfi(t, x), sigma = D**0.5)\n",
    "# No noise on state column 0; the remaining dim coordinates get noise scale sqrt(D). Built from float\n",
    "# tensors so torch.cat needs no int/float type promotion.\n",
    "sde_upfi = models.SDE(lambda t, x: v_upfi(t, x), sigma = torch.cat([torch.zeros(1), torch.full((dim, ), D**0.5)]).to(device))\n",
    "with torch.no_grad():\n",
    "    xs_t_pfi = torchsde.sdeint(sde_pfi, x0, _ts, method = \"euler\").cpu()\n",
    "    xs_t_upfi = torchsde.sdeint(sde_upfi, x0_mass, _ts, method = \"euler\").cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b65e27f4-4378-4e2f-9aad-0dcf28266efc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Total UPFI particle mass along the SDE trajectories (sum of exp of state column 0) vs the observed cell-count ratios.\n",
    "fig, ax = plt.subplots(figsize = (3, 3))\n",
    "total_mass = xs_t_upfi[..., 0].exp().sum(1)\n",
    "ax.plot(_ts, total_mass, label = \"UPFI\")\n",
    "ax.plot(ts.cpu(), m_ratios.cpu(), label = \"True\")\n",
    "ax.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ea627e4-8c9d-4d3a-835d-e7e255eeba3c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# xs_t_upfi is recorded on the fine grid _ts, so X_batch[i] (observed at ts[i]) is compared with the\n",
    "# grid point closest to ts[i], not with row i of the trajectory.\n",
    "_obs_idx = torch.argmin((_ts[:, None] - ts.cpu()[None, :]).abs(), dim = 0).tolist()\n",
    "[evals.sinkhorn_divergence(xs_t_upfi[j, :, 1:], X_batch[i][:, 1:].cpu(), x_w = xs_t_upfi[j, :, 0].exp()) for i, j in enumerate(_obs_idx)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5751a078-a779-4612-a0ea-01688a840060",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# xs_t_pfi is recorded on the fine grid _ts, so X_batch[i] (observed at ts[i]) is compared with the\n",
    "# grid point closest to ts[i], not with row i of the trajectory.\n",
    "_obs_idx = torch.argmin((_ts[:, None] - ts.cpu()[None, :]).abs(), dim = 0).tolist()\n",
    "[evals.sinkhorn_divergence(xs_t_pfi[j, :, :], X_batch[i][:, 1:].cpu()) for i, j in enumerate(_obs_idx)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d245bf72-7ad0-4696-8e9c-fb6c649c2c57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fate probabilities for this `seed`; used below via the keys probs_upfi, probs_pfi, probs_ode and probs_tigon.\n",
    "fate_data = torch.load(f\"fate_probs_seed_{seed}.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ae17296-052a-4ac7-a589-c85cd32a7b71",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Column 0 of each method's fate probabilities for the t_idx == 0 cells, on the SPRING embedding.\n",
    "_t0 = data['t_idx'] == 0\n",
    "_spring_t0 = data['x_spring'][_t0]\n",
    "plt.figure(figsize = (10, 2.5))\n",
    "for panel, (title, key) in enumerate([(\"UPFI\", 'probs_upfi'), (\"PFI\", 'probs_pfi'), (\"ODE\", 'probs_ode'), (\"TIGON\", 'probs_tigon')], start = 1):\n",
    "    plt.subplot(1, 4, panel)\n",
    "    plt.scatter(_spring_t0[:, 0], _spring_t0[:, 1], c = fate_data[key][:, 0].cpu(), vmin = 0, vmax = 1, cmap = \"bwr\", alpha = 0.5, s = 0.1, rasterized = True)\n",
    "    plt.title(title); plt.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/LARRY_spring_t0_fate.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9238e865-51c5-4058-b158-8ff83a6ae4cb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "df = pd.read_csv(f\"df_fate_seed_{seed}.csv\", index_col=0, header = 0)\n",
    "# Lineage-based neutrophil fate (pneu) vs each method's neutrophil probability, on the SPRING embedding.\n",
    "plt.figure(figsize = (12.5, 2.5))\n",
    "for panel, (title, column) in enumerate([(\"Lineage\", \"pneu\"), (\"PFI\", \"PFI_neu\"), (\"UPFI\", \"UPFI_neu\"), (\"ODE\", \"ODE_neu\"), (\"TIGON\", \"TIGON_neu\")], start = 1):\n",
    "    plt.subplot(1, 5, panel)\n",
    "    plt.scatter(df.SPRINGx, df.SPRINGy, c = df[column], cmap = \"bwr\", alpha = 0.5, s = 5, rasterized = True)\n",
    "    plt.title(title); plt.axis(\"off\")\n",
    "# Lay out only after every panel, including the last, has its axis turned off.\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ba0c21e-683d-4bb1-8214-e2f9d018ab0b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Lineage-based neutrophil fate (pneu) over all t_idx == 0 cells (grey), then column 0 of each method's fate probabilities.\n",
    "# The lineage table is re-read here so this cell does not depend on whichever cell last assigned the global `df`.\n",
    "_df_lineage = pd.read_csv(f\"df_fate_seed_{seed}.csv\", index_col=0, header = 0)\n",
    "_t0 = data['t_idx'] == 0\n",
    "plt.figure(figsize = (10, 2.5))\n",
    "plt.subplot(1, 5, 1)\n",
    "# Constant grey background: with a fixed colour `c`, cmap/vmin/vmax are ignored (Matplotlib warns), so they are omitted.\n",
    "plt.scatter(data['x_spring'][_t0, 0], data['x_spring'][_t0, 1], c = 'grey', alpha = 0.25, s = 1, rasterized = True)\n",
    "plt.scatter(_df_lineage.SPRINGx, _df_lineage.SPRINGy, c = _df_lineage.pneu, cmap = \"bwr\", alpha = 0.5, s = 5, rasterized = True)\n",
    "plt.title(\"Lineage\"); plt.axis(\"off\")\n",
    "for panel, (title, key) in enumerate([(\"UPFI\", 'probs_upfi'), (\"PFI\", 'probs_pfi'), (\"ODE\", 'probs_ode'), (\"TIGON\", 'probs_tigon')], start = 2):\n",
    "    plt.subplot(1, 5, panel)\n",
    "    plt.scatter(data['x_spring'][_t0, 0], data['x_spring'][_t0, 1], c = fate_data[key][:, 0].cpu(), vmin = 0, vmax = 1, cmap = \"bwr\", alpha = 0.5, s = 0.5, rasterized = True)\n",
    "    plt.title(title); plt.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/LARRY_spring_t0_fate_vs_lineage.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "051ed440-f7d3-4ad8-a03a-0e1fc6102582",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import scipy as sp\n",
    "import pandas as pd\n",
    "# Pearson-r summary over the five seed runs: mean +/- std per method, highest mean per row in bold.\n",
    "_df = pd.concat([pd.read_csv(f\"df_fate_pearsonr_seed_{s+1}.csv\", index_col=5, header = 0).iloc[:, 1:] for s in range(5)], axis = 0)\n",
    "_df = _df.loc[:, [\"UPFI\", \"PFI\", \"ODE\", \"TIGON\"]]\n",
    "# GroupBy.mean/std take no axis argument (std's first positional parameter is ddof), so call them without\n",
    "# positional arguments; std is then the sample std (ddof = 1), as in the RNA-velocity table.\n",
    "_by_what = _df.groupby('what')\n",
    "_df_mean, _df_std = _by_what.mean(), _by_what.std()\n",
    "_df_mean_str = _df_mean.applymap(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)\n",
    "_df_std_str = _df_std.applymap(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)\n",
    "_df_str = pd.DataFrame({_df_mean_str.columns[i] : _df_mean_str.iloc[:, i].str.cat(_df_std_str.iloc[:, i], sep = \" $\\\\pm$ \") for i in range(_df_mean_str.shape[1])})\n",
    "for i, j in enumerate(np.argmax(_df_mean.values, 1)):\n",
    "    _df_str.iloc[i, j] = \"\\\\textbf{\" + _df_str.iloc[i, j] + \"}\"\n",
    "print(_df_str.to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1c8da7f-40b0-45b9-bea3-33679e436d99",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Per-seed correlation table for the seed analysed in this notebook. The per-seed files use the same number as\n",
    "# df_fate_seed_{seed}.csv and fate_probs_seed_{seed}.pkl (seed, not seed + 1).\n",
    "pd.read_csv(f\"df_fate_pearsonr_seed_{seed}.csv\", index_col=5, header = 0).iloc[:, 1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "894c8fd1-496f-4853-9147-40dd64435e35",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "dfs = [pd.read_csv(f\"df_fate_seed_{s+1}.csv\", index_col=0, header = 0) for s in range(5)]\n",
    "\n",
    "hist_kwargs = {'bins' : 8}\n",
    "# (panel title, column, colour, alpha, runs to draw). Only the first run's lineage column is drawn;\n",
    "# each method is drawn semi-transparently for all five runs.\n",
    "_hist_panels = [(\"Lineage\", \"pneu\", 'k', None, dfs[:1]),\n",
    "                (\"UPFI\", \"UPFI_neu\", 'b', 0.3, dfs),\n",
    "                (\"PFI\", \"PFI_neu\", 'r', 0.3, dfs),\n",
    "                (\"ODE\", \"ODE_neu\", 'g', 0.3, dfs),\n",
    "                (\"TIGON\", \"TIGON_neu\", 'purple', 0.3, dfs)]\n",
    "plt.figure(figsize = (8, 1.5))\n",
    "for panel, (title, column, color, alpha, runs) in enumerate(_hist_panels, start = 1):\n",
    "    plt.subplot(1, 5, panel)\n",
    "    # Loop over `run` (not `df`) so the global `df` loaded for the single-seed figures is not overwritten.\n",
    "    for run in runs:\n",
    "        h, e = np.histogram(run[column], **hist_kwargs)\n",
    "        plt.stairs(h / h.sum(), e, label = title, color = color, alpha = alpha)\n",
    "    plt.title(title)\n",
    "    plt.xlabel('$p_{\\\\rm{Neu}}$'); plt.ylabel(\"Frequency\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/LARRY_fate_probs_hist.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40b8f6df-8a03-4ae0-95eb-44786a3f1e2b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efdbae87-39bb-4876-8b6e-54bf6a86d2bb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1399a2a1-4834-472d-9128-fc0dfcc9a6d5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "upfi",
   "language": "python",
   "name": "upfi"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
