{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Single-cell Time Series Interpolation\n",
    "\n",
    "This notebook runs OT-CFM, SB-CFM and SF2M on the embryoid body data, preprocessed according to the notebooks in `preprocessing`. Here we optimize a single network to model cells over time on the 2D PHATE projection. The same process applies for any other representation. In the paper we compare quantitatively on PCA components and highly variable genes.\n",
    "\n",
    "Note that to run this notebook you will need the `ebdata_v3.h5ad` data object, which is accessible at https://data.mendeley.com/datasets/hhny5ff7yj/1 along with the other single-cell datasets used in this work.\n",
    "\n",
    "Note that to reproduce the results in the paper, the PC dimensions are taken from this data object https://github.com/KrishnaswamyLab/TrajectoryNet/blob/master/data/eb_velocity_v5.npz, which is not quite the same as the PCs here due to different preprocessing. The highly variable genes were selected with the `sc.pp.highly_variable_genes` function with `n_top_genes=XXX`.\n",
    "\n",
    "To train the model here we build a batch of all timepoint pairs together. This seems to be the most stable if large batches are affordable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scanpy as sc\n",
    "import torch\n",
    "import torchsde\n",
    "from torchdyn.core import NeuralODE\n",
    "from tqdm import tqdm\n",
    "\n",
    "from torchcfm.conditional_flow_matching import *\n",
    "from torchcfm.models import MLP\n",
    "from torchcfm.utils import plot_trajectories, torch_wrapper\n",
    "\n",
    "savedir = \"models/single-cell\"\n",
    "os.makedirs(savedir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the data object named in the notebook introduction (`ebdata_v3.h5ad`).\n",
    "adata = sc.read_h5ad(\"../data/ebdata_v3.h5ad\")\n",
    "adata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.pl.scatter(adata, basis=\"phate\", color=\"sample_labels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_times = len(adata.obs[\"sample_labels\"].unique())\n",
    "# Standardize coordinates\n",
    "coords = adata.obsm[\"X_phate\"]\n",
    "coords = (coords - coords.mean(axis=0)) / coords.std(axis=0)\n",
    "adata.obsm[\"X_phate_standardized\"] = coords\n",
    "X = [\n",
    "    adata.obsm[\"X_phate_standardized\"][adata.obs[\"sample_labels\"].cat.codes == t]\n",
    "    for t in range(n_times)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "batch_size = 256\n",
    "sigma = 0.1\n",
    "dim = 2\n",
    "ot_cfm_model = MLP(dim=dim, time_varying=True, w=64).to(device)\n",
    "ot_cfm_optimizer = torch.optim.Adam(ot_cfm_model.parameters(), 1e-4)\n",
    "FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scprep\n",
    "\n",
    "\n",
    "def plot_trajectories(traj, legend=True):\n",
    "    \"\"\"Overlay simulated trajectories on the standardized PHATE embedding.\n",
    "\n",
    "    This notebook-local definition shadows the `plot_trajectories` imported from\n",
    "    `torchcfm.utils` and reads the module-level `adata` object.\n",
    "\n",
    "    Args:\n",
    "        traj: array indexed as [time step, sample, coordinate]; coordinates 0 and 1 are drawn.\n",
    "        legend: when True, attach the fixed list of legend labels to the current axes.\n",
    "    \"\"\"\n",
    "    n = 2000  # cap on the number of samples drawn in the point cloud\n",
    "    n_paths = min(15, traj.shape[1])  # avoid an IndexError when fewer than 15 samples are given\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "    # Background: every cell in the dataset, colored by its sample label.\n",
    "    scprep.plot.scatter(\n",
    "        adata.obsm[\"X_phate_standardized\"][:, 0],\n",
    "        adata.obsm[\"X_phate_standardized\"][:, 1],\n",
    "        c=adata.obs[\"sample_labels\"],\n",
    "        ax=ax,\n",
    "    )\n",
    "    # Faint cloud of all intermediate positions of the first n samples.\n",
    "    ax.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.4, alpha=0.1, c=\"olive\")\n",
    "\n",
    "    # Highlight a handful of individual paths.\n",
    "    for i in range(n_paths):\n",
    "        ax.plot(traj[:, i, 0], traj[:, i, 1], alpha=0.9, c=\"red\")\n",
    "    if legend:\n",
    "        plt.legend([r\"$p_0$\", r\"$p_t$\", r\"$p_1$\", r\"$X_t \\mid X_0$\"])\n",
    "\n",
    "\n",
    "def get_batch(FM, X, batch_size, n_times, return_noise=False):\n",
    "    \"\"\"Assemble one training batch pooled over every consecutive timepoint pair.\n",
    "\n",
    "    For each pair (k, k + 1), `batch_size` rows are drawn at random from `X[k]` and `X[k + 1]`,\n",
    "    passed to `FM.sample_location_and_conditional_flow`, and the sampled time is shifted by k.\n",
    "    The per-pair results are concatenated along the first axis.\n",
    "\n",
    "    Returns `(t, xt, ut)`, with the concatenated noise appended as a fourth item when\n",
    "    `return_noise` is True. Tensors are placed on the module-level `device`.\n",
    "    \"\"\"\n",
    "\n",
    "    def draw(points):\n",
    "        # Random row indices (repeats allowed), converted to a float tensor on `device`.\n",
    "        picks = np.random.randint(points.shape[0], size=batch_size)\n",
    "        return torch.from_numpy(points[picks]).float().to(device)\n",
    "\n",
    "    times, locations, targets, eps_chunks = [], [], [], []\n",
    "    for k in range(n_times - 1):\n",
    "        start_pts = draw(X[k])\n",
    "        end_pts = draw(X[k + 1])\n",
    "        sample = FM.sample_location_and_conditional_flow(\n",
    "            start_pts, end_pts, return_noise=return_noise\n",
    "        )\n",
    "        if return_noise:\n",
    "            t_k, x_k, u_k, eps_k = sample\n",
    "            eps_chunks.append(eps_k)\n",
    "        else:\n",
    "            t_k, x_k, u_k = sample\n",
    "        times.append(t_k + k)\n",
    "        locations.append(x_k)\n",
    "        targets.append(u_k)\n",
    "\n",
    "    batch = (torch.cat(times), torch.cat(locations), torch.cat(targets))\n",
    "    if return_noise:\n",
    "        return batch + (torch.cat(eps_chunks),)\n",
    "    return batch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## OT-CFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regress the network output onto the target vector field returned by get_batch.\n",
    "for i in tqdm(range(10000)):\n",
    "    ot_cfm_optimizer.zero_grad()\n",
    "    t, xt, ut = get_batch(FM, X, batch_size, n_times)\n",
    "    # The model takes the position with the time appended as the last feature.\n",
    "    net_input = torch.cat([xt, t[:, None]], dim=-1)\n",
    "    vt = ot_cfm_model(net_input)\n",
    "    loss = (vt - ut).pow(2).mean()\n",
    "    loss.backward()\n",
    "    ot_cfm_optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "node = NeuralODE(torch_wrapper(ot_cfm_model), solver=\"dopri5\", sensitivity=\"adjoint\")\n",
    "with torch.no_grad():\n",
    "    traj = node.trajectory(\n",
    "        torch.from_numpy(X[0][:1000]).float().to(device),\n",
    "        t_span=torch.linspace(0, n_times - 1, 400),\n",
    "    ).cpu()\n",
    "    plot_trajectories(traj.cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(\n",
    "    {\n",
    "        \"model\": ot_cfm_model,\n",
    "        \"optimizer\": ot_cfm_optimizer,\n",
    "    },\n",
    "    f\"{savedir}/ot_cfm_single_cell.pt\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SF2M\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sigma = 0.25\n",
    "sf2m_model = MLP(dim=dim, time_varying=True, w=64).to(device)\n",
    "sf2m_score_model = MLP(dim=dim, time_varying=True, w=64).to(device)\n",
    "sf2m_optimizer = torch.optim.AdamW(\n",
    "    list(sf2m_model.parameters()) + list(sf2m_score_model.parameters()), 1e-4\n",
    ")\n",
    "SF2M = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_norm_ut = torch.tensor(0.0)\n",
    "# Train the flow network and the score network jointly on the sum of their losses.\n",
    "for i in tqdm(range(10000)):\n",
    "    sf2m_optimizer.zero_grad()\n",
    "    t, xt, ut, eps = get_batch(SF2M, X, batch_size, n_times, return_noise=True)\n",
    "    # get_batch shifts t by the interval index, so t % 1 recovers the within-interval time.\n",
    "    lambda_t = SF2M.compute_lambda(t % 1)\n",
    "    # Both networks receive the same input: position with time appended as the last feature.\n",
    "    net_input = torch.cat([xt, t[:, None]], dim=-1)\n",
    "    vt = sf2m_model(net_input)\n",
    "    st = sf2m_score_model(net_input)\n",
    "    flow_loss = (vt - ut).pow(2).mean()\n",
    "    score_loss = (lambda_t[:, None] * st + eps).pow(2).mean()\n",
    "    loss = flow_loss + score_loss\n",
    "    if i % 1000 == 0:\n",
    "        print(f\"{i}: {flow_loss.item():0.2f}, {score_loss.item():0.2f}\")\n",
    "\n",
    "    loss.backward()\n",
    "    sf2m_optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "node = NeuralODE(torch_wrapper(sf2m_model), solver=\"euler\", sensitivity=\"adjoint\")\n",
    "x0 = torch.from_numpy(X[0][:1000]).float()\n",
    "with torch.no_grad():\n",
    "    traj = node.trajectory(\n",
    "        x0.to(device),\n",
    "        t_span=torch.linspace(0, n_times - 1, 400, device=device),\n",
    "    ).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_trajectories(traj.cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SDE(torch.nn.Module):\n",
    "    \"\"\"Pairs a drift network with a score network behind the `f`/`g` methods passed to `torchsde.sdeint`.\"\"\"\n",
    "\n",
    "    noise_type = \"diagonal\"\n",
    "    sde_type = \"ito\"\n",
    "\n",
    "    def __init__(self, ode_drift, score, input_size=(3, 32, 32), sigma=1.0):\n",
    "        \"\"\"Store the two networks, the per-sample state shape, and the diffusion scale.\"\"\"\n",
    "        super().__init__()\n",
    "        self.drift = ode_drift\n",
    "        self.score = score\n",
    "        self.input_size = input_size\n",
    "        self.sigma = sigma\n",
    "\n",
    "    def f(self, t, y):\n",
    "        \"\"\"Drift term: sum of the drift-network and score-network outputs at (y, t).\"\"\"\n",
    "        y = y.view(-1, *self.input_size)\n",
    "        # When t already has the same rank as y it is appended as-is;\n",
    "        # otherwise it is repeated once per row of y and turned into a column.\n",
    "        if len(t.shape) == len(y.shape):\n",
    "            time_col = t\n",
    "        else:\n",
    "            time_col = t.repeat(y.shape[0])[:, None]\n",
    "        net_input = torch.cat([y, time_col], 1)\n",
    "        drift_out = self.drift(net_input).flatten(start_dim=1)\n",
    "        score_out = self.score(net_input).flatten(start_dim=1)\n",
    "        return drift_out + score_out\n",
    "\n",
    "    def g(self, t, y):\n",
    "        \"\"\"Diffusion term: the constant `sigma` expanded to the shape of y.\"\"\"\n",
    "        return torch.ones_like(y) * self.sigma\n",
    "\n",
    "\n",
    "sde = SDE(sf2m_model, sf2m_score_model, input_size=(2,), sigma=sigma)\n",
    "eval_times = torch.linspace(0, n_times - 1, 400, device=device)\n",
    "with torch.no_grad():\n",
    "    sde_traj = torchsde.sdeint(sde, x0.to(device), ts=eval_times).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_trajectories(sde_traj.detach().cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(\n",
    "    {\n",
    "        \"model\": sf2m_model,\n",
    "        \"score_model\": sf2m_score_model,\n",
    "        \"optimizer\": sf2m_optimizer,\n",
    "    },\n",
    "    f\"{savedir}/sf2m_single_cell_sigma_{sigma}.pt\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evolution of trajectories between SF2M's probability flow ODE (SB-CFM) and SF2M"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Starting from a cell at day 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "node = NeuralODE(torch_wrapper(sf2m_model), solver=\"euler\", sensitivity=\"adjoint\")\n",
    "x0 = torch.from_numpy(X[0][:1000]).float()\n",
    "with torch.no_grad():\n",
    "    traj = node.trajectory(\n",
    "        x0[2].repeat(20).view(20, 2).to(device),\n",
    "        t_span=torch.linspace(0, n_times - 1, 400, device=device),\n",
    "    ).cpu()\n",
    "# plot_trajectories(traj.cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    sde_traj = torchsde.sdeint(\n",
    "        sde,\n",
    "        x0[2].repeat(20).view(20, 2).to(device),\n",
    "        ts=torch.linspace(0, n_times - 1, 400, device=device),\n",
    "    ).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "traj = traj.detach().cpu().numpy()\n",
    "sde_traj = sde_traj.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 2000\n",
    "f, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))\n",
    "\n",
    "ax1.scatter(traj[0, :n, 0], traj[0, :n, 1], s=100, alpha=1, marker=\"d\", c=\"brown\", zorder=3)\n",
    "ax1.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=100, alpha=1, marker=\"d\", c=\"blue\", zorder=3)\n",
    "\n",
    "for i in range(5):\n",
    "    ax1.plot(traj[:, i, 0], traj[:, i, 1], alpha=0.9, c=\"black\", zorder=2)\n",
    "\n",
    "scprep.plot.scatter(\n",
    "    adata.obsm[\"X_phate_standardized\"][:, 0],\n",
    "    adata.obsm[\"X_phate_standardized\"][:, 1],\n",
    "    c=adata.obs[\"sample_labels\"],\n",
    "    ax=ax1,\n",
    "    legend=None,\n",
    ")\n",
    "\n",
    "ax1.set_xticks([])\n",
    "ax1.set_yticks([])\n",
    "ax1.set_title(\"Trajectories from a given sample with SB-CFM\", fontsize=15)\n",
    "ax1.legend([r\"$x_0$\", r\"$x_1$\", r\"$x_t \\mid x_0$\"])\n",
    "\n",
    "\n",
    "ax2.scatter(\n",
    "    sde_traj[0, :n, 0], sde_traj[0, :n, 1], s=100, alpha=1, marker=\"d\", c=\"brown\", zorder=3\n",
    ")\n",
    "ax2.scatter(\n",
    "    sde_traj[-1, :n, 0], sde_traj[-1, :n, 1], s=100, alpha=1, marker=\"d\", c=\"blue\", zorder=3\n",
    ")\n",
    "\n",
    "for i in range(5):\n",
    "    ax2.plot(sde_traj[:, i, 0], sde_traj[:, i, 1], alpha=0.9, c=\"black\", zorder=2)\n",
    "\n",
    "scprep.plot.scatter(\n",
    "    adata.obsm[\"X_phate_standardized\"][:, 0],\n",
    "    adata.obsm[\"X_phate_standardized\"][:, 1],\n",
    "    c=adata.obs[\"sample_labels\"],\n",
    "    ax=ax2,\n",
    "    legend=None,\n",
    ")\n",
    "\n",
    "ax2.legend([r\"$x_0$\", r\"$x_1$\", r\"$x_t \\mid x_0$\"])\n",
    "\n",
    "ax2.set_xticks([])\n",
    "ax2.set_yticks([])\n",
    "ax2.set_title(\n",
    "    r\"Trajectories from a given sample with SF2M $(\\sigma={})$\".format(sigma), fontsize=15\n",
    ")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"conditonal_trajectory_sf2m_sigma_{}.png\".format(sigma))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Starting from a cell at an older time (day 25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integrate 15 copies of one cell taken from X[1], starting at t=1, with both the ODE and the SDE.\n",
    "node = NeuralODE(torch_wrapper(sf2m_model), solver=\"euler\", sensitivity=\"adjoint\")\n",
    "x0 = torch.from_numpy(X[1][:1000]).float()\n",
    "with torch.no_grad():\n",
    "    traj = node.trajectory(\n",
    "        x0[1].repeat(15).view(15, 2).to(device),\n",
    "        t_span=torch.linspace(1, n_times - 1, 300, device=device),\n",
    "    ).cpu()\n",
    "\n",
    "with torch.no_grad():\n",
    "    sde_traj = torchsde.sdeint(\n",
    "        sde,\n",
    "        x0[1].repeat(15).view(15, 2).to(device),\n",
    "        ts=torch.linspace(1, n_times - 1, 300, device=device),\n",
    "    ).cpu()\n",
    "\n",
    "# Convert both results to NumPy arrays for plotting, as done in the day-0 section.\n",
    "traj = traj.detach().cpu().numpy()\n",
    "sde_traj = sde_traj.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 2000\n",
    "f, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))\n",
    "\n",
    "\n",
    "ax1.scatter(traj[0, :n, 0], traj[0, :n, 1], s=100, alpha=1, marker=\"d\", c=\"brown\", zorder=3)\n",
    "ax1.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=100, alpha=1, marker=\"d\", c=\"blue\", zorder=3)\n",
    "\n",
    "for i in range(5):\n",
    "    ax1.plot(traj[:, i, 0], traj[:, i, 1], alpha=0.9, c=\"black\", zorder=2)\n",
    "\n",
    "scprep.plot.scatter(\n",
    "    adata.obsm[\"X_phate_standardized\"][:, 0],\n",
    "    adata.obsm[\"X_phate_standardized\"][:, 1],\n",
    "    c=adata.obs[\"sample_labels\"],\n",
    "    ax=ax1,\n",
    "    legend=None,\n",
    ")\n",
    "\n",
    "ax1.set_xticks([])\n",
    "ax1.set_yticks([])\n",
    "ax1.set_title(\"Trajectories from a given sample with SB-CFM\", fontsize=15)\n",
    "ax1.legend([r\"$x_0$\", r\"$x_1$\", r\"$X_t \\mid X_0$\"])\n",
    "\n",
    "\n",
    "ax2.scatter(\n",
    "    sde_traj[0, :n, 0], sde_traj[0, :n, 1], s=100, alpha=1, marker=\"d\", c=\"brown\", zorder=3\n",
    ")\n",
    "ax2.scatter(\n",
    "    sde_traj[-1, :n, 0], sde_traj[-1, :n, 1], s=100, alpha=1, marker=\"d\", c=\"blue\", zorder=3\n",
    ")\n",
    "\n",
    "for i in range(5):\n",
    "    ax2.plot(sde_traj[:, i, 0], sde_traj[:, i, 1], alpha=0.9, c=\"black\", zorder=2)\n",
    "\n",
    "\n",
    "scprep.plot.scatter(\n",
    "    adata.obsm[\"X_phate_standardized\"][:, 0],\n",
    "    adata.obsm[\"X_phate_standardized\"][:, 1],\n",
    "    c=adata.obs[\"sample_labels\"],\n",
    "    ax=ax2,\n",
    "    legend=None,\n",
    ")\n",
    "\n",
    "ax2.legend([r\"$x_0$\", r\"$x_1$\", r\"$x_t \\mid x_0$\"])\n",
    "\n",
    "ax2.set_xticks([])\n",
    "ax2.set_yticks([])\n",
    "ax2.set_title(\n",
    "    r\"Trajectories from a given sample with SF2M $(\\sigma={})$\".format(sigma), fontsize=15\n",
    ")\n",
    "# plt.axis(\"off\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"conditonal_trajectory_sf2m_sigma_{}_from_later_time.png\".format(sigma))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Because SB-CFM is a deterministic process, it ends up at only one final location, in contrast to our SF2M model, which models a stochastic process and therefore produces different outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torchcfm2",
   "language": "python",
   "name": "torchcfm2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}