{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06b8ac32-5c46-41a8-8e75-6de5fa3962c9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "num_threads = \"4\"\n",
    "os.environ[\"OMP_NUM_THREADS\"] = num_threads\n",
    "os.environ[\"OPENBLAS_NUM_THREADS\"] = num_threads\n",
    "os.environ[\"MKL_NUM_THREADS\"] = num_threads\n",
    "os.environ[\"VECLIB_MAXIMUM_THREADS\"] = num_threads\n",
    "os.environ[\"NUMEXPR_NUM_THREADS\"] = num_threads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e47efd2c-0fd9-4acc-9fa0-06b94c06f3ef",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import sys\n",
    "sys.path.append(\"../../src\")\n",
    "import importlib\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import linear_model\n",
    "import dcor\n",
    "import torchsde\n",
    "import pandas as pd\n",
    "from torchdiffeq import odeint\n",
    "import sklearn as sk\n",
    "from sklearn import preprocessing\n",
    "torch.set_default_dtype(torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a148d830-24d2-463c-90eb-9605dcb5d190",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects; only use a trusted file.\n",
    "data = torch.load(\"data_cellcycle_pca.pkl\", weights_only=False)\n",
    "# Apply one common permutation to samples, velocities and time labels so the labels end up sorted.\n",
    "_idx = np.argsort(data['t_idx'])\n",
    "for _key in ('x', 'v', 't_idx'):\n",
    "    data[_key] = data[_key][_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d107011a-854f-4acb-af7b-fcc8975aee9b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import fm\n",
    "from tqdm import tqdm\n",
    "import copy\n",
    "\n",
    "d = data['x'].shape[-1]  # size of the last axis of the samples\n",
    "_time_labels = np.sort(np.unique(data['t_idx']))\n",
    "T = len(_time_labels)  # number of distinct time labels\n",
    "ts = _time_labels / (T-1)  # time labels divided by T-1\n",
    "\n",
    "# One LinearEntropicOTFM per scale factor; only the multiplier applied to data['A'] changes.\n",
    "scale_factors = torch.linspace(0, 100, 11)\n",
    "_shared_kwargs = dict(ts = ts, sigma = 0.5, mu = data['mu'], T = T, dim = d, device = torch.device('cpu'))\n",
    "otfms = [\n",
    "    fm.LinearEntropicOTFM(data['x'], data['t_idx'], A = data['A'] * s, **_shared_kwargs)\n",
    "    for s in tqdm(scale_factors)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a67b8b2b-ff30-42f7-a349-fa0fd6ede46b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Average L2 distance of the rows of data['x'] from their mean taken over axis 0.\n",
    "(data['x'] - data['x'].mean(0)).norm(dim = 1).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2607e1f8-c788-45c5-83a4-1cd1cca7f224",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Average L2 norm of the rows of data['v'].\n",
    "data['v'].norm(dim = 1).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d958a973-f02e-44a9-8a90-ea7b5204030f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from scipy.interpolate import griddata\n",
    "import numpy as np\n",
    "\n",
    "def plot_stream_vectorfield(x, v, ax = None, x_min = None, x_max = None, y_min = None, y_max = None, **kwargs):\n",
    "    x_min = x[:, 0].min() if x_min is None else x_min\n",
    "    x_max = x[:, 0].max() if x_max is None else x_max\n",
    "    y_min = x[:, 0].min() if y_min is None else y_min\n",
    "    y_max = x[:, 0].max() if y_max is None else y_max\n",
    "    _x, _y = np.meshgrid(np.linspace(x_min, x_max, 15), np.linspace(y_min, y_max, 15))\n",
    "    _u = griddata((x[:, 0], x[:, 1]), v[:, 0], (_x, _y), method='linear', fill_value = 0)\n",
    "    _v = griddata((x[:, 0], x[:, 1]), v[:, 1], (_x, _y), method='linear', fill_value = 0)\n",
    "    if ax is None:\n",
    "        plt.streamplot(_x, _y, _u, _v, **kwargs)\n",
    "    else:\n",
    "        ax.streamplot(_x, _y, _u, _v, **kwargs)\n",
    "\n",
    "plt.figure(figsize = (6, 3))\n",
    "# Drift of the linear reference model at every sample: (x - mu) A^T.\n",
    "v_ref = (data['x'] - data['mu']) @ data['A'].T\n",
    "_panels = [(data['v'], \"$v_{\\\\rm{dynamo}}$\"), (v_ref, \"$v_{\\\\rm{ref}}$\")]\n",
    "for _col, (_field, _title) in enumerate(_panels, start = 1):\n",
    "    plt.subplot(1, 2, _col)\n",
    "    # Samples coloured by time index, with streamlines of the first two components on top.\n",
    "    plt.scatter(data['x'][:, 0], data['x'][:, 1], c = data['t_idx'], alpha = 0.1, s = 100, rasterized = True)\n",
    "    plot_stream_vectorfield(data['x'][:, range(2)], _field[:, range(2)], color = 'k', linewidth = 0.5)\n",
    "    plt.title(_title)\n",
    "    plt.xlabel(\"PCA1\"); plt.ylabel(\"PCA2\")\n",
    "    plt.axis('off')\n",
    "    plt.colorbar()\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/cellcycle_pca_vf.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbc7e5ad-94fa-48e5-83a8-54d6fc4efcd5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class OrnsteinUhlenbeck(torch.nn.Module):\n",
    "    noise_type = 'diagonal'\n",
    "    sde_type = 'ito'\n",
    "    def __init__(self, A, sigma):\n",
    "        super().__init__()\n",
    "        self.A = A\n",
    "        self.sigma = sigma\n",
    "    def f(self, t, x):\n",
    "        return x @ self.A.T\n",
    "    # Diffusion\n",
    "    def g(self, t, x):\n",
    "        return torch.ones_like(x)*self.sigma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "766af7b4-f536-4fbc-8666-762aebfb9c37",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Gaussian approximation of every snapshot: empirical mean and covariance per time index.\n",
    "Xs, means, covs = [], [], []\n",
    "for _t_index in range(T):\n",
    "    _snapshot = data['x'][data['t_idx'] == _t_index]\n",
    "    Xs.append(_snapshot)\n",
    "    means.append(_snapshot.mean(0))\n",
    "    covs.append(torch.cov(_snapshot.T))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6c31365-09f8-4176-aa73-df1b61a4f8e9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from scipy.stats import multivariate_normal\n",
    "import matplotlib.cm\n",
    "\n",
    "def get_cmap(solid_color):\n",
    "    \"\"\"Return a colormap that ramps linearly from white (at 0) to `solid_color` (at 1).\n",
    "\n",
    "    `solid_color` is read at positions 0, 1 and 2 for the red, green and blue end values;\n",
    "    any further entries (for example an alpha channel) are ignored.\n",
    "    \"\"\"\n",
    "    # Each channel is a two-node ramp of (position, value below, value above) triples.\n",
    "    ramps = {\n",
    "        channel: [(0.0, 1.0, 1.0), (1.0, solid_color[pos], solid_color[pos])]\n",
    "        for pos, channel in enumerate(('red', 'green', 'blue'))\n",
    "    }\n",
    "    return matplotlib.colors.LinearSegmentedColormap('WhiteToBlue', ramps)\n",
    "\n",
    "# x = np.linspace(-0.5, 3.5, 100)\n",
    "# y = np.linspace(-0.5, 3.5, 100)\n",
    "x = np.linspace(-10, 10, 100)\n",
    "y = np.linspace(-10, 10, 100)\n",
    "X, Y = np.meshgrid(x, y)\n",
    "_X = np.vstack([X.reshape(-1), Y.reshape(-1)])\n",
    "pos = np.dstack((X, Y))\n",
    "\n",
    "def plot_bivariate(mean, cov, cm):\n",
    "    \"\"\"Contour-plot the density of N(mean, cov) on the module-level grid (X, Y, pos) with colormap `cm`.\"\"\"\n",
    "    density = multivariate_normal(mean, cov).pdf(pos)\n",
    "    plt.contour(X, Y, density, levels=5, cmap=cm)\n",
    "    \n",
    "def plot_bivariate_3d(mean, cov, cm):\n",
    "    \"\"\"Same 2-D contour plot as `plot_bivariate`; despite the name, nothing here is three-dimensional.\"\"\"\n",
    "    plot_bivariate(mean, cov, cm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "866dc6d2-a1bd-403b-9c3c-0fd372c9d8e8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "importlib.reload(fm)\n",
    "k = 0  # first of the two adjacent coordinates (k, k+1) shown in the plots\n",
    "offset = 4  # each bridge connects snapshot j to snapshot j+offset\n",
    "t = torch.linspace(0, 1, 8)  # times at which the bridge marginals are drawn\n",
    "otfm = otfms[5]\n",
    "# (multiplier applied to otfm.A, panel title): the left panel keeps the drift, the right one zeroes it.\n",
    "_panels = [(offset, \"Ornstein-Uhlenbeck reference\"), (0, \"Wiener reference\")]\n",
    "for j in range(T-offset):\n",
    "    plt.figure(figsize = (8, 4))\n",
    "    for _col, (_a_mult, _title) in enumerate(_panels, start = 1):\n",
    "        plt.subplot(1, 2, _col)\n",
    "        _otfm = fm.LinearEntropicOTFM(\n",
    "            otfm.x, otfm.t_idx,\n",
    "            ts = otfm.ts, sigma = otfm.sigma * offset**0.5, A = otfm.A * _a_mult,\n",
    "            mu = otfm.mu, T = otfm.T, dim = otfm.dim, device = otfm.device,\n",
    "        )\n",
    "        gsb = fm.GaussianOUSB(_otfm.bm, _otfm)\n",
    "        (mean0, var0), (mean1, var1) = (means[j], covs[j]), (means[j+offset], covs[j+offset])\n",
    "        sb_means, sb_vars, S_t, d_sb_means = gsb.evaluate(t, mean0, mean1, var0, var1)\n",
    "        # One contour set per evaluation time, coloured along the brg colormap.\n",
    "        for i, _t_i in enumerate(t):\n",
    "            cm = get_cmap(matplotlib.cm.brg(_t_i.item()))\n",
    "            plot_bivariate(sb_means[i].flatten()[k:k+2], sb_vars[i].reshape(d, d)[k:k+2, k:k+2], cm = cm)\n",
    "        # The two endpoint snapshots in colour, drawn over all samples in grey.\n",
    "        for _snap, _color, _label in ((j, 'blue', \"$p_{i-1}$\"), (j+offset, 'green', \"$p_{i+1}$\")):\n",
    "            idx = (data[\"t_idx\"] == _snap)\n",
    "            plt.scatter(data['x'][idx, k], data['x'][idx, k+1], alpha = 0.5, marker = \"^\", color = _color, label = _label)\n",
    "        plt.scatter(data['x'][:, k], data['x'][:, k+1], alpha = 0.1, marker = \"o\", color = 'grey', zorder = -100)\n",
    "        plt.legend(); plt.xlabel(\"$x_1$\"); plt.ylabel(\"$x_2$\")\n",
    "        plt.title(_title)\n",
    "    plt.tight_layout()\n",
    "    # The path does not depend on j, so every iteration overwrites the file written by the previous one.\n",
    "    plt.savefig(\"../../figures/cellcycle_pca_gsb_interp.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "687a7661-a104-41e9-91e0-5013560a6675",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Entry of scale_factors at index 5, the same index used to pick `otfm = otfms[5]`.\n",
    "scale_factors[5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dda8995-a338-4ffc-9b63-70af24f8eb66",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# For each drift scale, evaluate the Gaussian bridge at every entry of ts and compare it with the\n",
    "# per-snapshot Gaussian (means[i], covs[i]) through fm.bures_wasserstein.\n",
    "# NOTE(review): j, offset and otfm are not assigned in this cell -- they carry over from an\n",
    "# earlier cell, so the endpoints means[j] / means[j+offset] depend on what was run before.\n",
    "res = {}\n",
    "for s in tqdm(scale_factors):\n",
    "    _otfm = fm.LinearEntropicOTFM(\n",
    "        otfm.x, otfm.t_idx,\n",
    "        ts = otfm.ts, sigma = otfm.sigma * offset**0.5, A = data['A'] * offset * s,\n",
    "        mu = otfm.mu, T = otfm.T, dim = otfm.dim, device = otfm.device,\n",
    "    )\n",
    "    gsb = fm.GaussianOUSB(_otfm.bm, _otfm)\n",
    "    (mean0, var0), (mean1, var1) = (means[j], covs[j]), (means[j+offset], covs[j+offset])\n",
    "    _eval_ts = torch.tensor(ts).float()\n",
    "    sb_means, sb_vars, S_t, d_sb_means = gsb.evaluate(_eval_ts, mean0, mean1, var0, var1)\n",
    "    # One value per snapshot index, keyed by the plain-float scale factor.\n",
    "    res[s.item()] = [\n",
    "        fm.bures_wasserstein(sb_means[i].squeeze(), means[i], sb_vars[i], covs[i]).item()\n",
    "        for i in range(T)\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5733e7be-796b-4835-bc8f-83384cf4512e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import seaborn as sb\n",
    "df = pd.DataFrame(res)  # rows: snapshot index, columns: scale factor\n",
    "plt.figure(figsize = (7, 2))\n",
    "\n",
    "# Left panel: value at each snapshot, one line per scale factor.\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.title(\"Interpolation over time\")\n",
    "_df = (\n",
    "    df.reset_index()\n",
    "      .melt(id_vars = 'index', value_vars = df.columns)\n",
    "      .rename(columns = {'index' : 't', 'variable' : 'scale'})\n",
    ")\n",
    "ax = sb.lineplot(_df, x = \"t\", y = \"value\", hue = \"scale\", palette = \"viridis\", marker = \"o\")\n",
    "sb.move_legend(ax, \"upper left\", bbox_to_anchor=(1, 1))\n",
    "plt.ylabel(\"$W_2^2$\")\n",
    "\n",
    "# Right panel: column means of df (average over snapshots), with every second scale labelled.\n",
    "plt.subplot(1, 2, 2)\n",
    "sb.barplot(df.mean(0), color = 'lightgrey')\n",
    "_tick_positions = range(len(scale_factors))[::2]\n",
    "_tick_labels = scale_factors.int().numpy()[::2]\n",
    "plt.xticks(ticks = _tick_positions, labels = _tick_labels)\n",
    "plt.ylabel(\"$W_2^2$\")\n",
    "plt.xlabel(\"scale\")\n",
    "plt.title(\"Average error\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/cellcycle_interp_error.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e49968b-dc27-4ebc-9eb2-bc66a0cee796",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Train one score network and one flow network per selected scale factor.\n",
    "# NOTE(review): no random seed is set in this cell, so repeated runs will differ.\n",
    "hidden_sizes_score = [64, 64, 64]\n",
    "hidden_sizes_flow = [64, 64, 64]\n",
    "\n",
    "# Maps scale factor -> (score model, flow model), in that order.\n",
    "models = {}\n",
    "\n",
    "for s in [0, 50.0]:\n",
    "    print(f\"Training OT-FM for s = {s}\")\n",
    "    # Pick the pre-built problem whose scale factor equals s.\n",
    "    otfm = otfms[np.where(scale_factors == s)[0][0]]\n",
    "    s_model = fm.MLP(d = d, hidden_sizes = hidden_sizes_score, time_varying=True, activation = torch.nn.ReLU)\n",
    "    v_model = fm.MLP(d = d, hidden_sizes = hidden_sizes_flow, time_varying=True, activation = torch.nn.ReLU)\n",
    "    # A single optimiser updates both networks jointly.\n",
    "    optim = torch.optim.AdamW(list(s_model.parameters()) + list(v_model.parameters()), 1e-2)\n",
    "    # Weight of the flow loss in the total; the score loss gets 1 - alpha.\n",
    "    alpha = 0.5\n",
    "    for i in tqdm(range(1_000)):\n",
    "        _x, _s, _u, _t, _t_orig = otfm.sample_bridges_flows(batch_size = 64)\n",
    "        optim.zero_grad()\n",
    "        s_fit = s_model(_t, _x)\n",
    "        v_fit = v_model(_t, _x)\n",
    "        # Both residuals are weighted by _t_orig * (1 - _t_orig); the score loss is also scaled by sigma^2.\n",
    "        L_score = torch.mean(((_t_orig * (1-_t_orig)) * (s_fit - _s))**2) * otfm.sigma**2\n",
    "        L_flow = torch.mean((_t_orig * (1-_t_orig)*(v_fit - _u))**2)\n",
    "        # Unweighted alternative kept for reference:\n",
    "        # L_flow = torch.mean((v_fit - _u)**2)\n",
    "        L = (1-alpha)*L_score + alpha*L_flow\n",
    "        # Log both loss terms every 250 iterations.\n",
    "        if i % 250 == 0:\n",
    "            print(L_score.item(), L_flow.item())\n",
    "        L.backward()\n",
    "        optim.step()\n",
    "    # Stored score-first: consumers must unpack as (s_model, v_model).\n",
    "    models[s] = (s_model, v_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5b5db8e-4952-4246-82b8-83b9a226e69f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize = (10, 3))\n",
    "# Drift of the linear reference model at every sample: (x - mu) A^T (drawn in panel 2).\n",
    "v_ref = (data['x'] - data['mu']) @ data['A'].T\n",
    "\n",
    "# Panel 1: velocities stored in data['v'].\n",
    "plt.subplot(1, 4, 1)\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], c = data['t_idx'], alpha = 0.1, s = 100, rasterized = True);\n",
    "plot_stream_vectorfield(data['x'][:, range(2)], data['v'][:, range(2)], color = 'k', linewidth = 0.5)\n",
    "plt.xlabel(\"PCA1\"); plt.ylabel(\"PCA2\"); plt.axis('off')\n",
    "plt.title(\"$v_{\\\\rm{dynamo}}$\")\n",
    "\n",
    "# Panel 2: reference drift.\n",
    "plt.subplot(1, 4, 2)\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], c = data['t_idx'], alpha = 0.1, s = 100, rasterized = True);\n",
    "plot_stream_vectorfield(data['x'][:, range(2)], v_ref[:, range(2)], color = 'k', linewidth = 0.5)\n",
    "plt.title(\"$v_{\\\\rm{ref}}$\")\n",
    "plt.xlabel(\"PCA1\"); plt.ylabel(\"PCA2\"); plt.axis('off')\n",
    "\n",
    "# Panel 4: learned drift for scale 0.\n",
    "plt.subplot(1, 4, 4)\n",
    "# models[...] holds (score model, flow model) in that order; unpacking flow-first would swap\n",
    "# the two networks inside the drift v + sigma^2/2 * s below.\n",
    "(s_model, v_model), sigma = models[0.], otfm.sigma\n",
    "sde = fm.SDE(lambda t, x: v_model(t, x) + sigma**2 / 2 * s_model(t, x), sigma)\n",
    "with torch.no_grad():\n",
    "    # Evaluate each snapshot's samples at that snapshot's own time.\n",
    "    v_sb = torch.vstack([sde.f(torch.scalar_tensor(t), data['x'][data['t_idx'] == i, ...]) for i, t in enumerate(ts)])\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], c = data['t_idx'], alpha = 0.1, s = 100, rasterized = True);\n",
    "plot_stream_vectorfield(data['x'][:, range(2)], v_sb[:, range(2)], color = 'k', linewidth = 0.5)\n",
    "plt.title(\"$v_{\\\\rm{SB}}$ (BM ref.)\")\n",
    "plt.xlabel(\"PCA1\"); plt.ylabel(\"PCA2\"); plt.axis('off')\n",
    "\n",
    "# Panel 3: learned drift for scale 50.\n",
    "plt.subplot(1, 4, 3)\n",
    "(s_model, v_model), sigma = models[50.], otfm.sigma\n",
    "sde = fm.SDE(lambda t, x: v_model(t, x) + sigma**2 / 2 * s_model(t, x), sigma)\n",
    "with torch.no_grad():\n",
    "    v_sb_ou = torch.vstack([sde.f(torch.scalar_tensor(t), data['x'][data['t_idx'] == i, ...]) for i, t in enumerate(ts)])\n",
    "plt.scatter(data['x'][:, 0], data['x'][:, 1], c = data['t_idx'], alpha = 0.1, s = 100, rasterized = True);\n",
    "plot_stream_vectorfield(data['x'][:, range(2)], v_sb_ou[:, range(2)], color = 'k', linewidth = 0.5)\n",
    "plt.title(\"$v_{\\\\rm{SB}}$ (OU ref.)\")\n",
    "plt.xlabel(\"PCA1\"); plt.ylabel(\"PCA2\"); plt.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/cellcycle_pca_vf_with_fm.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed569721-721c-45c2-b390-144fd1f4a76b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lfm",
   "language": "python",
   "name": "lfm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
