{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee2f7a9c-714a-4376-b553-5263436880da",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import anndata as ad\n",
    "import scanpy as sc\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import os\n",
    "import sklearn as sk\n",
    "from sklearn import preprocessing, decomposition\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import seaborn as sb\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "torch.set_default_dtype(torch.float32)\n",
    "c = 0.5\n",
    "beta = 5.5\n",
    "seed = 1\n",
    "\n",
    "data = torch.load(f\"sim_HSC_N_500_T_10_c_{c}_beta_{beta}.pkl\", weights_only = False)\n",
    "data_nogrowth = torch.load(f\"sim_HSC_N_500_T_10_c_{c}_beta_0.pkl\", weights_only = False)\n",
    "adata = ad.AnnData(data['x'], {\"t_idx\" : data['t_idx']})\n",
    "sc.tl.pca(adata)\n",
    "sc.pl.scatter(adata, basis = \"pca\", color = \"t_idx\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f71ef0cf-0c32-4754-af32-d65ce32bf8e1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize = (3, 3))\n",
    "sb.barplot(pd.Series(data[\"t_idx\"]).value_counts().sort_index() / 500)\n",
    "plt.xlabel(\"Timepoint\"); plt.ylabel(\"Cells\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "957792df-01d8-46e9-9acb-fe452890a168",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "scaler_op = sk.preprocessing.StandardScaler().fit(adata.X)\n",
    "ts = torch.tensor(adata.obs.t_idx) \n",
    "Xs = adata.X # scaler_op.transform(adata.X)\n",
    "pca_op = sk.decomposition.PCA().fit(data_nogrowth['x'])\n",
    "Xs_pca = pca_op.transform(Xs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "964ace57-300c-4320-ba60-bfb077042e10",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize = (3.5, 3.5))\n",
    "ax = fig.add_subplot(1, 1, 1, projection='3d')\n",
    "ax.view_init(30, -100)\n",
    "scatter = ax.scatter(Xs_pca[:, 0], Xs_pca[:, 1], Xs_pca[:, 2], c=ts, cmap='Blues', alpha = 0.1, edgecolor = 'k', rasterized = True)\n",
    "ax.set_xlabel('X')\n",
    "ax.set_ylabel('Y')\n",
    "ax.set_zlabel('Z')\n",
    "plt.axis('off')\n",
    "plt.savefig(\"../../figures/boolode_HSC_pca_3d.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60f9a14f-5276-40e2-9520-0bae2103a66d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../src\")\n",
    "import models\n",
    "from torch import optim\n",
    "\n",
    "ts = torch.tensor(np.sort(np.unique(adata.obs.t_idx)), dtype = torch.float32)\n",
    "ts /= ts.max()\n",
    "dim = Xs.shape[1]\n",
    "T = adata.obs.t_idx.max()+1\n",
    "\n",
    "X = [torch.tensor(Xs[adata.obs.t_idx == i, ...], device = device, dtype = torch.float32) for i in range(T)]\n",
    "\n",
    "sigmas = torch.linspace(1, -2, 5, device = device).exp()\n",
    "hidden_sizes = [128, 128, 128]\n",
    "s = models.NCScoreFunc(d = dim, hidden_sizes = hidden_sizes, activation = torch.nn.ReLU, time_dependent = True).to(device)\n",
    "s.load_state_dict(torch.load(f\"weights/params_NCScoreFunc_default_c_{c}_seed_{seed}_final.pt\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8216d98-c976-4f08-8ce0-24087b8c04d6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "P = torch.vstack([torch.eye(2), torch.zeros(dim-2, 2)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2742d7d-60f6-49b1-b5a8-eb5fd21f9ee1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "s.to(device)\n",
    "samplers = [models.LangevinSampler(lambda x, sigma, _t = torch.scalar_tensor(_s).to(device): s(_t, x, sigma), \n",
    "                      torch.randn(1_000, dim).to(device), \n",
    "                      sigmas = sigmas, dt = 3e-3, n_iter = 1000) for _s in ts]\n",
    "\n",
    "x_sample = torch.vstack([s.sample().cpu()[None, ...] for s in samplers])\n",
    "X_all = torch.vstack(X).cpu()\n",
    "s.cpu()\n",
    "x_min, x_max = (X_all[:, 0]).min()-1, (X_all[:, 0]).max()+1\n",
    "y_min, y_max = (X_all[:, 1]).min()-1, (X_all[:, 1]).max()+1\n",
    "plt.figure(figsize = (10, 5))\n",
    "for i in range(len(X)):\n",
    "    x, y = torch.meshgrid(torch.linspace(x_min, x_max, 15), torch.linspace(y_min, y_max, 15))\n",
    "    _x = torch.vstack([x.flatten(), y.flatten()]).T\n",
    "    t = torch.scalar_tensor(ts[i])\n",
    "    plt.subplot(2,5, i+1)\n",
    "    with torch.no_grad():\n",
    "        _v = s(t, torch.tensor(_x @ P.T, dtype = torch.float32), sigmas[-1].cpu()) @ P\n",
    "        # _v = s(t, torch.tensor(_x @ P.T, dtype = torch.float32)) @ P\n",
    "    # plt.contourf(x, y, torch.linalg.norm(_v, dim = 1).reshape(x.shape), levels = 20, cmap = \"bone\");\n",
    "    plt.quiver(_x[:, 0], _x[:, 1], _v[:, 0], _v[:, 1], torch.linalg.norm(_v, dim = 1).reshape(x.shape), cmap='RdBu_r', scale = 1e3)\n",
    "    plt.scatter(X[i].cpu()[:, 0], X[i].cpu()[:, 1], s = 0.1, c = 'r', alpha = 1, zorder = 10)\n",
    "    plt.scatter(x_sample[i, :, 0].cpu(), x_sample[i, :, 1], s = 5, c = 'b', alpha = 0.1)\n",
    "    plt.xlim(x_min, x_max); plt.ylim(y_min, y_max)\n",
    "    plt.title(f\"t = {ts[i]:.2f}\")\n",
    "    plt.xlabel(\"x\"); plt.ylabel(\"x\")\n",
    "s.to(device)\n",
    "plt.suptitle(\"Learned score function\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/boolode_HSC_score_validation.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94a9735d-baee-455c-b529-1d5728fcbbba",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize = (7, 3.5))\n",
    "ax = fig.add_subplot(1, 2, 1, projection='3d'); ax.view_init(30, -100)\n",
    "ts = torch.tensor(adata.obs.t_idx) \n",
    "scatter = ax.scatter(Xs_pca[:, 0], Xs_pca[:, 1], Xs_pca[:, 2], c=ts, cmap='Blues', alpha = 0.25, edgecolor = 'k', rasterized = True)\n",
    "plt.axis('off')\n",
    "ax = fig.add_subplot(1, 2, 2, projection='3d'); ax.view_init(30, -100)\n",
    "x_sample_pca = pca_op.transform(x_sample.reshape(-1, dim))\n",
    "ts = torch.hstack([torch.full((x_sample.shape[1], ), i) for i in range(x_sample.shape[0])])\n",
    "ax.scatter(x_sample_pca[:, 0], x_sample_pca[:, 1], x_sample_pca[:, 2], alpha = 0.1, s = 10, c = ts, cmap = \"Blues\", edgecolor = 'k', rasterized = True)\n",
    "plt.axis('off')\n",
    "plt.savefig(\"../../figures/boolode_HSC_score_validation_pca3d.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ca023b0-1d58-494c-a97c-1a45dd64fd91",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "from tqdm import tqdm\n",
    "import geomloss\n",
    "import torchdiffeq\n",
    "from torchdiffeq import odeint\n",
    "import math\n",
    "import utils, losses, train\n",
    "import importlib\n",
    "importlib.reload(models); importlib.reload(utils); importlib.reload(losses); importlib.reload(train)\n",
    "\n",
    "D = 0.5**2\n",
    "ts = torch.linspace(0, 1, len(X)).to(device)\n",
    "m_ratios = torch.tensor([x.shape[0] / X[0].shape[0] for x in X])\n",
    "hidden_sizes = [128, 128, 128]\n",
    "\n",
    "v_upfi_mult = models.MultiplicativeNoiseFlowGrowth(dim, lambda t, x, sigma : s(t, x, sigma), D, sigmas[-1], \n",
    "                                                  kwargs_u = {'output_activation' : torch.nn.Softplus, 'time_dependent' : False, 'hidden_sizes' : hidden_sizes}, \n",
    "                                                  kwargs_v = {'output_activation' : torch.nn.Softplus, 'time_dependent' : False, 'hidden_sizes' : hidden_sizes}, \n",
    "                                                  kwargs_g = {'time_dependent' : False, 'hidden_sizes' : hidden_sizes[:1]}).to(device)\n",
    "v_upfi_mult.load_state_dict(torch.load(f'weights/params_MULT_UPFI_ODEFlowGrowth_default_c_{c}_seed_{seed}_final.pt'))\n",
    "\n",
    "v_pfi_mult = models.MultiplicativeNoiseFlow(dim, lambda t, x, sigma : s(t, x, sigma), D, sigmas[-1], \n",
    "                                                  kwargs_u = {'output_activation' : torch.nn.Softplus, 'time_dependent' : False, 'hidden_sizes' : hidden_sizes}, \n",
    "                                                  kwargs_v = {'output_activation' : torch.nn.Softplus, 'time_dependent' : False, 'hidden_sizes' : hidden_sizes}).to(device)\n",
    "v_pfi_mult.load_state_dict(torch.load(f'weights/params_MULT_PFI_VectorField_default_c_{c}_seed_{seed}_final.pt'))\n",
    "\n",
    "v_upfi = models.ODEFlowGrowth(d = dim, kwargs_v = {'hidden_sizes' : hidden_sizes, 'time_dependent' : False}, \n",
    "                                       kwargs_g = {'hidden_sizes' : hidden_sizes[:1], 'time_dependent' : False}).to(device)\n",
    "v_upfi.load_state_dict(torch.load(f'weights/params_UPFI_ODEFlowGrowth_default_c_{c}_seed_{seed}_final.pt'))\n",
    "\n",
    "\n",
    "v_pfi = models.VectorField(d = dim, hidden_sizes = hidden_sizes, time_dependent = True).to(device)\n",
    "v_pfi.load_state_dict(torch.load(f'weights/params_PFI_VectorField_default_c_{c}_seed_{seed}_final.pt'))\n",
    "\n",
    "x0_mass = utils.sample_batch_upfi(X, m_ratios.to(device), batch_size = 1024, replacement=True, add_noise=False)[0]\n",
    "x0 = x0_mass[..., 1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f06e3bee-e316-4d4f-88ca-5147ebb545e1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torchsde\n",
    "v_upfi_mult.to(device); v_pfi_mult.to(device); v_pfi.to(device); v_upfi.to(device);\n",
    "sde_pfi = models.SDE(lambda t, x: v_pfi(t, x), sigma = D**0.5)\n",
    "sde_upfi = models.SDE(lambda t, x: v_upfi(t, x), sigma = torch.cat([torch.tensor([0, ]), torch.full((dim, ), D**0.5)]).to(device))\n",
    "sde_pfi_mult = models.MultiplicativeNoiseSDE(v_pfi_mult.u, v_pfi_mult.v, sigma = D**0.5)\n",
    "sde_upfi_mult = models.MultiplicativeNoiseSDE(v_upfi_mult.u, v_upfi_mult.v, sigma = D**0.5)\n",
    "with torch.no_grad():\n",
    "    xs_t_pfi = torchsde.sdeint(sde_pfi, x0, ts, method = \"euler\").cpu()\n",
    "    xs_t_upfi = torchsde.sdeint(sde_upfi, x0_mass, ts, method = \"euler\").cpu()[..., 1:]\n",
    "    xs_t_pfi_mult = torchsde.sdeint(sde_pfi_mult, x0, ts, method = \"euler\").cpu()\n",
    "    xs_t_upfi_mult = torchsde.sdeint(sde_upfi_mult, x0, ts, method = \"euler\").cpu()\n",
    "    \n",
    "xs_t_pfi_ = xs_t_pfi.reshape(-1, xs_t_pfi.shape[-1]).cpu()\n",
    "xs_t_upfi_ = xs_t_upfi.reshape(-1, xs_t_upfi.shape[-1]).cpu()\n",
    "xs_t_pfi_mult_ = xs_t_pfi_mult.reshape(-1, xs_t_pfi_mult.shape[-1]).cpu()\n",
    "xs_t_upfi_mult_ = xs_t_upfi_mult.reshape(-1, xs_t_upfi_mult.shape[-1]).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47ddf850-20c9-450b-83b8-b796a110044d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import seaborn as sb\n",
    "pca_op = sk.decomposition.PCA().fit(data_nogrowth['x'])\n",
    "Xs_pca = pca_op.transform(Xs)\n",
    "with torch.no_grad():\n",
    "    xs_t_pfi_pca_ = pca_op.transform(xs_t_pfi.reshape(-1, xs_t_pfi.shape[-1]).cpu())\n",
    "    xs_t_upfi_pca_ = pca_op.transform(xs_t_upfi.reshape(-1, xs_t_upfi.shape[-1]).cpu())\n",
    "    xs_t_pfi_mult_pca_ = pca_op.transform(xs_t_pfi_mult.reshape(-1, xs_t_pfi_mult.shape[-1]).cpu())\n",
    "    xs_t_upfi_mult_pca_ = pca_op.transform(xs_t_upfi_mult.reshape(-1, xs_t_upfi_mult.shape[-1]).cpu())\n",
    "    \n",
    "xs_t_pfi_pca = xs_t_pfi_pca_.reshape(xs_t_pfi.shape)\n",
    "xs_t_upfi_pca = xs_t_upfi_pca_.reshape(xs_t_upfi.shape)\n",
    "xs_t_pfi_mult_pca = xs_t_pfi_mult_pca_.reshape(xs_t_pfi_mult.shape)\n",
    "xs_t_upfi_mult_pca = xs_t_upfi_mult_pca_.reshape(xs_t_upfi_mult.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c11b5d0-6bb9-4751-899c-714ba94a21e4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "X_paths_pca = pca_op.transform(data['x_paths'].reshape(-1, data['x'].shape[1])).reshape(data['x_paths'].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88c8b6ca-ce91-491b-9bfa-de2df48a3f47",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "k = 0\n",
    "fig=plt.figure(figsize = (15, 3))\n",
    "ax = fig.add_subplot(1, 5, 1, projection='3d'); ax.view_init(30, -100)\n",
    "ax.scatter(Xs_pca[:, k], Xs_pca[:, k+1], Xs_pca[:, k+2], c=adata.obs.t_idx, cmap='Blues', alpha = 0.1, edgecolor = 'k', rasterized = True)\n",
    "for i in range(50):\n",
    "    ax.plot(X_paths_pca[i, :, k], X_paths_pca[i, :, k+1], X_paths_pca[i, :, k+2], c = 'r', alpha = 0.3, zorder = 100)\n",
    "ax.set_xlim(-8, 8); ax.set_ylim(-3, 3); ax.set_zlim(-5, 5); plt.axis(\"off\")\n",
    "plt.title(\"Data\")\n",
    "ax = fig.add_subplot(1, 5, 2, projection='3d'); ax.view_init(30, -100)\n",
    "ax.scatter(xs_t_upfi_pca_[:, k], xs_t_upfi_pca_[:, k+1], xs_t_upfi_pca_[:, k+2], c=ts.repeat_interleave(xs_t_upfi.shape[1]).cpu(), cmap='Blues', alpha = 0.1, edgecolor = 'k', rasterized = True)\n",
    "ax.view_init(30, -100)\n",
    "for i in range(50):\n",
    "    ax.plot(xs_t_upfi_pca[:, i, k], xs_t_upfi_pca[:, i, k+1], xs_t_upfi_pca[:, i, k+2], c = 'r', alpha = 0.3, zorder = 100)\n",
    "ax.set_xlim(-8, 8); ax.set_ylim(-3, 3); ax.set_zlim(-5, 5); plt.axis(\"off\")\n",
    "plt.title(\"UPFI Additive\")\n",
    "ax = fig.add_subplot(1, 5, 4, projection='3d'); ax.view_init(30, -100)\n",
    "ax.scatter(xs_t_pfi_pca_[:, k], xs_t_pfi_pca_[:, k+1], xs_t_pfi_pca_[:, k+2], c=ts.repeat_interleave(xs_t_pfi.shape[1]).cpu(), cmap='Blues', alpha = 0.1, edgecolor = 'k', rasterized = True)\n",
    "for i in range(50):\n",
    "    ax.plot(xs_t_pfi_pca[:, i, k], xs_t_pfi_pca[:, i, k+1], xs_t_pfi_pca[:, i, k+2], c = 'r', alpha = 0.3, zorder = 100)\n",
    "ax.set_xlim(-8, 8); ax.set_ylim(-3, 3); ax.set_zlim(-5, 5); plt.axis(\"off\")\n",
    "plt.title(\"PFI Additive\")\n",
    "ax = fig.add_subplot(1, 5, 3, projection='3d'); ax.view_init(30, -100)\n",
    "ax.scatter(xs_t_upfi_mult_pca_[:, k], xs_t_upfi_mult_pca_[:, k+1], xs_t_upfi_mult_pca_[:, k+2], c=ts.repeat_interleave(xs_t_upfi_mult.shape[1]).cpu(), cmap='Blues', alpha = 0.1, edgecolor = 'k', rasterized = True)\n",
    "for i in range(50):\n",
    "    ax.plot(xs_t_upfi_mult_pca[:, i, k], xs_t_upfi_mult_pca[:, i, k+1], xs_t_upfi_mult_pca[:, i, k+2], c = 'r', alpha = 0.3, zorder = 100)\n",
    "ax.set_xlim(-8, 8); ax.set_ylim(-3, 3); ax.set_zlim(-5, 5); plt.axis(\"off\")\n",
    "plt.title(\"UPFI Mult\")\n",
    "ax = fig.add_subplot(1, 5, 5, projection='3d'); ax.view_init(30, -100)\n",
    "ax.scatter(xs_t_pfi_mult_pca_[:, k], xs_t_pfi_mult_pca_[:, k+1], xs_t_pfi_mult_pca_[:, k+2], c=ts.repeat_interleave(xs_t_pfi_mult.shape[1]).cpu(), cmap='Blues', alpha = 0.1, edgecolor = 'k', rasterized = True)\n",
    "for i in range(50):\n",
    "    ax.plot(xs_t_pfi_mult_pca[:, i, k], xs_t_pfi_mult_pca[:, i, k+1], xs_t_pfi_mult_pca[:, i, k+2], c = 'r', alpha = 0.3, zorder = 100)\n",
    "ax.set_xlim(-8, 8); ax.set_ylim(-3, 3); ax.set_zlim(-5, 5); plt.axis(\"off\")\n",
    "plt.title(\"PFI Mult\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/boolode_HSC_sample_paths_pca3d_coarse.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d41c5ae3-b711-4f60-84fe-2c2f9f000861",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Points\n",
    "import sklearn as sk\n",
    "from sklearn import cluster, decomposition\n",
    "pca_op = sk.decomposition.PCA()\n",
    "y_true = pca_op.fit_transform(data['x'][data['t_idx'] == T-1, :])\n",
    "y_upfi = pca_op.transform(xs_t_upfi[-1, ...])\n",
    "y_pfi = pca_op.transform(xs_t_pfi[-1, ...])\n",
    "clust_op = sk.cluster.KMeans(n_clusters=4)\n",
    "clusts = clust_op.fit_predict(y_true)\n",
    "\n",
    "fig = plt.figure(figsize = (5, 5))\n",
    "ax = fig.add_subplot(1, 1, 1, projection='3d'); ax.view_init(30, -60)\n",
    "ax.scatter(y_true[:, 0], y_true[:, 1], y_true[:, 2], c = clusts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e922e8e-a2e5-4099-b90b-688a7c0fd9c5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "df = pd.DataFrame(\n",
    "    {\"UPFI\" : pd.value_counts(clust_op.predict(y_upfi)), \n",
    "     \"PFI\" : pd.value_counts(clust_op.predict(y_pfi)), \n",
    "     \"True\" : pd.value_counts(clust_op.predict(y_true))})\n",
    "df / df.sum(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41cad1ea-acd4-4008-b437-595eaefbae48",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Paths \n",
    "pca_op = sk.decomposition.PCA(n_components=10)\n",
    "y_true = pca_op.fit_transform(data['x_paths'].reshape(data['x_paths'].shape[0], -1))\n",
    "y_upfi = pca_op.transform(xs_t_upfi.permute((1, 0, 2)).reshape(xs_t_upfi.shape[1], -1))\n",
    "y_pfi = pca_op.transform(xs_t_pfi.permute((1, 0, 2)).reshape(xs_t_pfi.shape[1], -1))\n",
    "y_upfi_mult = pca_op.transform(xs_t_upfi_mult.permute((1, 0, 2)).reshape(xs_t_upfi_mult.shape[1], -1))\n",
    "y_pfi_mult = pca_op.transform(xs_t_pfi_mult.permute((1, 0, 2)).reshape(xs_t_pfi_mult.shape[1], -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97c725bb-30b1-4efa-ab5e-6af0990309fb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "clust_op = sk.cluster.KMeans(n_clusters=4)\n",
    "clusts = clust_op.fit_predict(y_true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1659fdd1-ccdc-4bbb-86ce-33aa2bd6d507",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize = (5, 5))\n",
    "ax = fig.add_subplot(1, 1, 1, projection='3d'); ax.view_init(30, -60)\n",
    "ax.scatter(y_true[:, 0], y_true[:, 1], y_true[:, 2], c = clusts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8f45c3c-7694-479a-a693-ee78b71bd272",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "pd.DataFrame(\n",
    "    {\"UPFI\" : pd.value_counts(clust_op.predict(y_upfi)), \n",
    "     \"UPFI_mult\" : pd.value_counts(clust_op.predict(y_upfi_mult)), \n",
    "     \"PFI\" : pd.value_counts(clust_op.predict(y_pfi)), \n",
    "     \"PFI_mult\" : pd.value_counts(clust_op.predict(y_pfi_mult)), \n",
    "     \"True\" : pd.value_counts(clust_op.predict(y_true))})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03752bc4-4970-4932-82b7-306b471221c8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize = (5, 5))\n",
    "ax = fig.add_subplot(1, 1, 1, projection='3d'); ax.view_init(30, -60)\n",
    "ax.scatter(Xs_pca[:, k], Xs_pca[:, k+1], Xs_pca[:, k+2], c=data['beta'], cmap='viridis', alpha = 0.1, s = 1)\n",
    "_idx = np.where(clusts == 3)[0]\n",
    "ax.scatter(X_paths_pca[_idx, -1, k], X_paths_pca[_idx, -1, k+1], X_paths_pca[_idx, -1, k+2], c = 'r', alpha = 0.3, zorder = 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ae9db5f-3a68-4c28-9e13-2408bb0115e9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import evals\n",
    "importlib.reload(evals)\n",
    "sb.barplot({'UPFI' : evals.energy_distance_paths(xs_t_upfi.permute((1, 0, 2)).numpy(), data['x_paths']),\n",
    "            'UPFI_mult' : evals.energy_distance_paths(xs_t_upfi_mult.permute((1, 0, 2)).numpy(), data['x_paths']),\n",
    "            'PFI' : evals.energy_distance_paths(xs_t_pfi.permute((1, 0, 2)).numpy(), data['x_paths']),\n",
    "            'PFI_mult' : evals.energy_distance_paths(xs_t_pfi_mult.permute((1, 0, 2)).numpy(), data['x_paths']),\n",
    "           }, palette = 'tab10')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bb2e3ef-c482-4f1c-a0e1-b66c7f8cdadc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "f_true = torch.tensor(data['f'], dtype = torch.float32)\n",
    "g_true = torch.tensor(data['g'], dtype = torch.float32)\n",
    "_x = torch.tensor(data['x'], dtype = torch.float32).to(device)\n",
    "with torch.no_grad():\n",
    "    u_est_upfi_mult = v_upfi_mult.u.net(_x).cpu()\n",
    "    v_est_upfi_mult = v_upfi_mult.v.net(_x).cpu()\n",
    "    u_est_pfi_mult = v_pfi_mult.u.net(_x).cpu()\n",
    "    v_est_pfi_mult = v_pfi_mult.v.net(_x).cpu()\n",
    "    vf_pfi = torch.vstack([v_pfi(ts[i], torch.tensor(data['x'][data['t_idx'] == i, :], dtype = torch.float32).to(device)) for i in range(T)]).cpu()\n",
    "    vf_upfi = torch.vstack([v_upfi.v_net(ts[i], torch.tensor(data['x'][data['t_idx'] == i, :], dtype = torch.float32).to(device)) for i in range(T)]).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9a2966e-8547-4005-ab48-3a9eb71042aa",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import evals, utils\n",
    "importlib.reload(utils)\n",
    "utils.cos_dist(u_est_upfi_mult - v_est_upfi_mult, f_true).mean(), utils.cos_dist(u_est_pfi_mult - v_est_pfi_mult, f_true).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2aa7b57-fbad-467d-8759-712b6d22be86",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "utils.cos_dist(vf_pfi, f_true).mean(), utils.cos_dist(vf_upfi, f_true).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "440fab42-2313-4ca5-8188-a3d4cabd362a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize = (5, 5))\n",
    "ax = fig.add_subplot(1, 1, 1, projection='3d'); ax.view_init(30, -100)\n",
    "ax.scatter(Xs_pca[:, 0], Xs_pca[:, 1], Xs_pca[:, 2], c=data['x'][:, 1], cmap='viridis', alpha = 0.5, s = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c91ff41c-c4c9-4269-9766-b83c1318a082",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "sb.barplot({\"PFI\" : utils.l2_dist(vf_pfi, f_true),\n",
    "            \"PFI (Mult.)\" : utils.l2_dist(u_est_pfi_mult - v_est_pfi_mult, f_true),\n",
    "            \"UPFI\" : utils.l2_dist(vf_upfi, f_true),\n",
    "            \"UPFI (Mult.)\" : utils.l2_dist(u_est_upfi_mult - v_est_upfi_mult, f_true)\n",
    "           })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aebcc72e-a263-429c-85e0-b55883677bc2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.subplot(2, 2, 1)\n",
    "sb.heatmap(vf_pfi, vmax = 10, vmin = -10, cmap = \"RdBu_r\")\n",
    "plt.title(\"$u_{infer} - v_{infer}$\")\n",
    "plt.subplot(2, 2, 2)\n",
    "sb.heatmap(f_true, vmin = -10, vmax = 10, cmap = \"RdBu_r\")\n",
    "plt.title(\"$u_{true} - v_{true}$\")\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dd4714a-183d-4b76-a763-168b48c3932e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.subplot(2, 2, 1)\n",
    "sb.heatmap(vf_upfi, vmax = 10, vmin = -10, cmap = \"RdBu_r\")\n",
    "plt.title(\"$u_{infer} - v_{infer}$\")\n",
    "plt.subplot(2, 2, 2)\n",
    "sb.heatmap(f_true, vmin = -10, vmax = 10, cmap = \"RdBu_r\")\n",
    "plt.title(\"$u_{true} - v_{true}$\")\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dcb44ac-2acd-4d01-acf0-00b81d72a718",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.subplot(2, 2, 1)\n",
    "sb.heatmap(u_est_upfi_mult - v_est_upfi_mult, vmax = 5, vmin = -5, cmap = \"RdBu_r\")\n",
    "plt.title(\"$u_{infer} - v_{infer}$\")\n",
    "plt.subplot(2, 2, 2)\n",
    "sb.heatmap(f_true, vmin = -5, vmax = 5, cmap = \"RdBu_r\")\n",
    "plt.title(\"$u_{true} - v_{true}$\")\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db3bcbe4-e9b2-4ebe-9502-024ccef9cf53",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.subplot(2, 2, 1)\n",
    "sb.heatmap(u_est_pfi_mult - v_est_pfi_mult, vmax = 5, vmin = -5, cmap = \"RdBu_r\")\n",
    "plt.title(\"$u_{infer} - v_{infer}$\")\n",
    "plt.subplot(2, 2, 2)\n",
    "sb.heatmap(f_true, vmin = -5, vmax = 5, cmap = \"RdBu_r\")\n",
    "plt.title(\"$u_{true} - v_{true}$\")\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8507118b-db6d-448c-b4af-9a4b9a6cc2cb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import math\n",
    "importlib.reload(models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ade4c36-4db5-4562-a8a6-8997ec2df6b7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "v_upfi_mult.g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ce6e147-4a68-4f42-8630-197693ea12da",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "v_upfi.g_net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0eb56c85-679b-4008-b6e2-0c0c32576046",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize = (10, 4))\n",
    "ax = fig.add_subplot(1, 2, 1, projection='3d'); ax.view_init(30, -100)\n",
    "p = scatter = ax.scatter(Xs_pca[:, 0], Xs_pca[:, 1], Xs_pca[:, 2], c=data['beta'], cmap='RdBu_r', vmin = -5, vmax = 5, alpha = 0.25, edgecolor = 'k')\n",
    "fig.colorbar(p)\n",
    "ax = fig.add_subplot(1, 2, 2, projection='3d'); ax.view_init(30, -100)\n",
    "with torch.no_grad():\n",
    "    # g_est = v_upfi.g_net(_, torch.tensor(data['x'], dtype = torch.float32).to(device)).cpu().flatten()\n",
    "    g_est = v_upfi_mult.g(_, torch.tensor(data['x'], dtype = torch.float32).to(device)).cpu().flatten()\n",
    "p = scatter = ax.scatter(Xs_pca[:, 0], Xs_pca[:, 1], Xs_pca[:, 2], c=g_est, cmap='RdBu_r', vmin = -5, vmax = 5, alpha = 0.25, edgecolor = 'k')\n",
    "fig.colorbar(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c193ef06-1901-498a-9382-5e615a5dd856",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.scatter(data['beta'], g_est, alpha = 0.1)\n",
    "plt.ylim(-10, 10.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0aeb2e29-34c0-413a-a444-865fe2af4e30",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import scipy as sp\n",
    "sp.stats.pearsonr(data['beta'], g_est)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba54a4dd-cf67-4bd7-8437-b2bf10f1f5c2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "odeint_options = {'method' : 'euler', 'options' : {'step_size' : 0.05}}\n",
    "xs_t_upfi = odeint(lambda t, x: v_upfi(t, x), x0_mass.to(device), ts, **odeint_options).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db6aa324-c05e-41da-838b-d3903fae4df5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "xs_t = xs_t_upfi\n",
    "plt.figure(figsize = (3, 3))\n",
    "plt.plot(ts.cpu(), [x.shape[0] / X[0].shape[0] for x in X], label = \"True\")\n",
    "plt.plot(ts.cpu(), [xs_t[i, :, 0].exp().sum().item() for i in range(xs_t.shape[0])], label = \"Fit\")\n",
    "plt.title(\"Total system mass\")\n",
    "plt.legend(); plt.xlabel(\"t\"); plt.ylabel(\"Mass\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b10248f4-a350-4011-9b94-545bd6592801",
   "metadata": {},
   "outputs": [],
   "source": [
    "sb.clustermap(data['centroids'], figsize = (3, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5b18c21-5d43-49f7-9033-eee78009546f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "probs = torch.load(f\"evals/fate_probs_seed_{seed}.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "105f0d4d-d9d2-4268-b27a-fe98a6f0365e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "_idx = data['t_idx'] == 8\n",
    "plt.figure(figsize = (10, 7.5))\n",
    "for i in range(4):\n",
    "    plt.subplot(3, 4, i+1)\n",
    "    plt.scatter(Xs_pca[_idx, 0], Xs_pca[_idx, 1], c = data['probs'][_idx, i], alpha = 0.3)\n",
    "    plt.title(\"True\"); plt.axis(\"off\")\n",
    "for i in range(4):\n",
    "    plt.subplot(3, 4, 4+i+1)\n",
    "    plt.scatter(Xs_pca[_idx, 0], Xs_pca[_idx, 1], c = probs['probs_upfi'][_idx, i], alpha = 0.3)\n",
    "    plt.title(\"UPFI\"); plt.axis(\"off\")\n",
    "for i in range(4):\n",
    "    plt.subplot(3, 4, 8+i+1)\n",
    "    plt.scatter(Xs_pca[_idx, 0], Xs_pca[_idx, 1], c = probs['probs_pfi'][_idx, i], alpha = 0.3)\n",
    "    plt.title(\"PFI\"); plt.axis(\"off\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4db18693-c2d0-436f-a654-5321de6af5f6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "k = 0\n",
    "i = 0\n",
    "fig=plt.figure(figsize = (15, 3))\n",
    "ax = fig.add_subplot(1, 5, 1, projection='3d'); ax.view_init(30, -100)\n",
    "ax.scatter(Xs_pca[:, k], Xs_pca[:, k+1], Xs_pca[:, k+2], c=data['probs'][:, i], cmap='bwr', alpha = 0.1, edgecolor = 'k')\n",
    "plt.title(\"Data\")\n",
    "ax = fig.add_subplot(1, 5, 2, projection='3d'); ax.view_init(30, -100)\n",
    "ax.scatter(Xs_pca[:, k], Xs_pca[:, k+1], Xs_pca[:, k+2], c=probs['probs_upfi'][:, i], cmap='bwr', alpha = 0.25, edgecolor = 'k')\n",
    "plt.title(\"UPFI Additive\")\n",
    "ax = fig.add_subplot(1, 5, 3, projection='3d'); ax.view_init(30, -100)\n",
    "ax.scatter(Xs_pca[:, k], Xs_pca[:, k+1], Xs_pca[:, k+2], c=probs['probs_pfi'][:, i], cmap='bwr', alpha = 0.25, edgecolor = 'k')\n",
    "plt.title(\"PFI Additive\")\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0409b1df-c6d5-4155-8592-55b824c6b94e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sb\n",
    "plt.figure(figsize = (3, 3))\n",
    "sb.barplot({\"UPFI\" : (probs['probs_upfi'] - data['probs']).abs().sum(-1).mean().item(), \n",
    "            \"UPFI Mult\" : (probs['probs_upfi_mult'] - data['probs']).abs().sum(-1).mean().item(), \n",
    "            \"PFI\" : (probs['probs_pfi'] - data['probs']).abs().sum(-1).mean().item(), \n",
    "            \"PFI Mult\" : (probs['probs_pfi_mult'] - data['probs']).abs().sum(-1).mean().item(), \n",
    "           }, palette = \"tab10\")\n",
    "plt.title(\"Total variation distance: fate probabilities\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa9a22c4-6a8c-425f-81cf-3b292808d820",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize = (2.5, 2.5))\n",
    "plt.scatter(Xs_pca[:, 0], Xs_pca[:, 1], alpha = 0.05, s = 25, c = data['x'][:, 3], cmap = 'bwr', vmin = -0.15, vmax = 1.15, rasterized = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "357f79a8-664a-438c-99f5-08d6a8c4953d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import plotting\n",
    "\n",
    "pca_op = sk.decomposition.PCA().fit(data_nogrowth['x'])\n",
    "i=0  # column of the fate-probability arrays used to colour the points\n",
    "\n",
    "def _draw_fate_panel(slot, fate_probs, field, label):\n",
    "    # Draw one panel of the 1x5 figure in subplot position `slot`:\n",
    "    # a scatter of the first two PCA coordinates coloured by column `i` of `fate_probs`,\n",
    "    # overlaid with the stream plot of `field` multiplied by the transposed PCA components.\n",
    "    plt.subplot(1, 5, slot)\n",
    "    plt.scatter(Xs_pca[:, 0], Xs_pca[:, 1], alpha = 0.05, s = 25, c = fate_probs[:, i], cmap = 'bwr', vmin = -0.15, vmax = 1.15, rasterized = True)\n",
    "    plotting.plot_stream_vectorfield(Xs_pca, (field @ pca_op.components_.T), color = 'k')\n",
    "    plt.title(label)\n",
    "    plt.axis('off')\n",
    "\n",
    "plt.figure(figsize = (10, 2.5))\n",
    "# Panels are created in the order 1, 2, 4, 3, 5: reference, additive variants, then multiplicative variants.\n",
    "_draw_fate_panel(1, data['probs'], f_true, \"True\")\n",
    "_draw_fate_panel(2, probs['probs_upfi'], vf_upfi, \"UPFI (add.)\")\n",
    "_draw_fate_panel(4, probs['probs_pfi'], vf_pfi, \"PFI (add.)\")\n",
    "# For the multiplicative variants the plotted field is the difference u_est - v_est.\n",
    "_draw_fate_panel(3, probs['probs_upfi_mult'], u_est_upfi_mult-v_est_upfi_mult, \"UPFI (mult.)\")\n",
    "_draw_fate_panel(5, probs['probs_pfi_mult'], u_est_pfi_mult-v_est_pfi_mult, \"PFI (mult.)\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/boolode_HSC_vf_fates_pca2d.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9189af80-7514-4adb-8e2c-a5dd87bb134d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "import glob\n",
    "from toolz import interleave\n",
    "import numpy as np\n",
    "\n",
    "# Build a one-row table of 'mean +/- std' strings from all energy-distance CSVs under evals/.\n",
    "files = glob.glob(\"evals/df_energy_distance_paths*.csv\")\n",
    "df = pd.concat([pd.read_csv(f, index_col = 0) for f in files])\n",
    "_df_mean = df.agg(['mean', ])\n",
    "_df_std = df.agg(['std', ])\n",
    "# Format every numeric entry with two decimals. A column-wise Series.map is used instead of\n",
    "# DataFrame.applymap, which is deprecated since pandas 2.1; Series.map exists in older versions too.\n",
    "# reset_index() gives both frames the same 0-based row index so that str.cat can align them.\n",
    "_df_mean_str = _df_mean.apply(lambda col: col.map(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)).reset_index()\n",
    "_df_std_str = _df_std.apply(lambda col: col.map(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)).reset_index()\n",
    "# Join mean and std position by position; the leading column created by reset_index() is dropped again.\n",
    "_df_str = pd.DataFrame({_df_mean_str.columns[k] : _df_mean_str.iloc[:, k].str.cat(_df_std_str.iloc[:, k], sep = \" $\\\\pm$ \") for k in range(_df_mean_str.shape[1])}).iloc[:, 1:]\n",
    "# Wrap the smallest mean of each row in \\\\textbf{...}. Dedicated loop names keep the\n",
    "# notebook-global `i` (used as the fate index by the plotting cells) from being overwritten.\n",
    "for _row, _col in enumerate(np.argmin(_df_mean.values, 1)):\n",
    "    _df_str.iloc[_row, _col] = \"\\\\textbf{\" + _df_str.iloc[_row, _col] + \"}\"\n",
    "_df_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b1d2226-882c-4eb2-929f-f7ca2fac8973",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Build a one-row table of 'mean +/- std' strings from all vector-field distance CSVs under evals/.\n",
    "files = glob.glob(\"evals/df_vf_l2_dist*.csv\")\n",
    "# Element-wise square root of every loaded value.\n",
    "# NOTE(review): presumably the CSVs store squared L2 distances - confirm against the code that writes them.\n",
    "df = pd.concat([pd.read_csv(f, index_col = 0) for f in files])**0.5\n",
    "# Reorder the columns by position: original columns 2 and 3 first, then 0 and 1.\n",
    "df=df.iloc[:, [2, 3, 0, 1]]\n",
    "_df_mean = df.agg(['mean', ])\n",
    "_df_std = df.agg(['std', ])\n",
    "# Format every numeric entry with two decimals. A column-wise Series.map is used instead of\n",
    "# DataFrame.applymap, which is deprecated since pandas 2.1; Series.map exists in older versions too.\n",
    "# reset_index() gives both frames the same 0-based row index so that str.cat can align them.\n",
    "_df_mean_str = _df_mean.apply(lambda col: col.map(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)).reset_index()\n",
    "_df_std_str = _df_std.apply(lambda col: col.map(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x)).reset_index()\n",
    "# Join mean and std position by position; the leading column created by reset_index() is dropped again.\n",
    "_df_str = pd.DataFrame({_df_mean_str.columns[k] : _df_mean_str.iloc[:, k].str.cat(_df_std_str.iloc[:, k], sep = \" $\\\\pm$ \") for k in range(_df_mean_str.shape[1])}).iloc[:, 1:]\n",
    "# Wrap the smallest mean of each row in \\\\textbf{...}. Dedicated loop names keep the\n",
    "# notebook-global `i` (used as the fate index by the plotting cells) from being overwritten.\n",
    "for _row, _col in enumerate(np.argmin(_df_mean.values, 1)):\n",
    "    _df_str.iloc[_row, _col] = \"\\\\textbf{\" + _df_str.iloc[_row, _col] + \"}\"\n",
    "_df_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3e8494b-16a6-45ad-994d-7408d3724b23",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Print the LaTeX source of the current `_df_str` table (print keeps real line breaks, unlike the repr).\n",
    "_latex_src = _df_str.to_latex()\n",
    "print(_latex_src)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76d7181e-d75f-4e7e-a62a-553d91730ae0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Build a table of 'mean +/- std' strings from all fate-TV CSVs under evals/, one row per value of 'what'.\n",
    "files = glob.glob(\"evals/df_fate_tv*.csv\")\n",
    "df = pd.concat([pd.read_csv(f, index_col = 0) for f in files])\n",
    "_df_mean = df.groupby('what').agg(['mean', ])\n",
    "_df_std = df.groupby('what').agg(['std', ])\n",
    "# Format every numeric entry with two decimals. A column-wise Series.map is used instead of\n",
    "# DataFrame.applymap, which is deprecated since pandas 2.1; Series.map exists in older versions too.\n",
    "# Both frames stay indexed by the 'what' groups, so str.cat aligns them on that index.\n",
    "_df_mean_str = _df_mean.apply(lambda col: col.map(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x))\n",
    "_df_std_str = _df_std.apply(lambda col: col.map(lambda x: f\"{x:.2f}\" if isinstance(x, (int, float)) else x))\n",
    "# Join mean and std position by position; output columns take the labels of the mean frame.\n",
    "_df_str = pd.DataFrame({_df_mean_str.columns[k] : _df_mean_str.iloc[:, k].str.cat(_df_std_str.iloc[:, k], sep = \" $\\\\pm$ \") for k in range(_df_mean_str.shape[1])})\n",
    "# Wrap the smallest mean of each row in \\\\textbf{...}. Dedicated loop names keep the\n",
    "# notebook-global `i` (used as the fate index by the plotting cells) from being overwritten.\n",
    "for _row, _col in enumerate(np.argmin(_df_mean.values, 1)):\n",
    "    _df_str.iloc[_row, _col] = \"\\\\textbf{\" + _df_str.iloc[_row, _col] + \"}\"\n",
    "_df_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24d85c64-91fa-46d3-b94a-803fadc32738",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the row index into a regular column, then print the LaTeX source of the current `_df_str` table.\n",
    "_latex_src = _df_str.reset_index().to_latex()\n",
    "print(_latex_src)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd90e8f7-5884-44fd-8444-5971873b9e12",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c53e3b3-37f3-49da-81fa-03e66e06a49f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "upfi",
   "language": "python",
   "name": "upfi"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
