{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee2f7a9c-714a-4376-b553-5263436880da",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import anndata as ad\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scanpy as sc\n",
    "import seaborn as sb\n",
    "import sklearn as sk\n",
    "import torch\n",
    "from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib versions\n",
    "from sklearn import preprocessing, decomposition  # binds sk.preprocessing / sk.decomposition\n",
    "\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "torch.set_default_dtype(torch.float32)\n",
    "\n",
    "# Experiment configuration: these values select the simulation and checkpoint files loaded below.\n",
    "c = 0.5\n",
    "beta = 5.0  # presumably the growth parameter: the no-growth reference file below uses beta_0\n",
    "seed = 1    # seed tag of the trained checkpoints; also used to seed sampling in this notebook\n",
    "\n",
    "# Langevin sampling, batch sampling and SDE simulation later on are stochastic;\n",
    "# seed the RNGs so that Restart & Run All reproduces the figures.\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "# The .pkl files presumably hold pickled Python objects rather than plain tensors, hence\n",
    "# weights_only=False. Unpickling can execute arbitrary code: only load trusted files.\n",
    "data_nogrowth = torch.load(f'sim_BF_beta_0_N_100_T_10_c_{c}.pkl', weights_only=False)\n",
    "data = torch.load(f'sim_BF_beta_{beta}_N_500_T_10_c_{c}.pkl', weights_only=False)\n",
    "adata = ad.AnnData(data['x'], {'t_idx': data['t_idx']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f71ef0cf-0c32-4754-af32-d65ce32bf8e1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(2.5, 2.5))\n",
    "# Number of cells per timepoint, normalised by the count at the earliest timepoint.\n",
    "counts = pd.Series(data['t_idx']).value_counts().sort_index()\n",
    "relative_mass = counts / counts.iloc[0]  # positional: does not assume a timepoint labelled 0\n",
    "sb.barplot(relative_mass)\n",
    "plt.xlabel('Timepoint'); plt.ylabel('Relative mass')\n",
    "plt.tight_layout()  # keep the axis labels inside the small 2.5 x 2.5 inch canvas\n",
    "plt.savefig('../../figures/boolode_BF_relativemass_vs_time.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "957792df-01d8-46e9-9acb-fe452890a168",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fit the preprocessing on the beta = 0 reference simulation so every dataset in this\n",
    "# notebook is projected onto the same principal axes.\n",
    "scaler_op = sk.preprocessing.StandardScaler().fit(data_nogrowth['x'])  # fitted but not applied to Xs\n",
    "pca_op = sk.decomposition.PCA().fit(data_nogrowth['x'])\n",
    "\n",
    "ts = torch.tensor(adata.obs.t_idx)  # per-cell timepoint index, used to colour plots\n",
    "Xs = adata.X  # unscaled; the scaler_op.transform(adata.X) variant is left disabled\n",
    "Xs_pca = pca_op.transform(Xs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "964ace57-300c-4320-ba60-bfb077042e10",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(3.5, 3.5))\n",
    "ax = fig.add_subplot(1, 1, 1, projection='3d')\n",
    "ax.view_init(45, -120)\n",
    "# Observed cells in the first three principal components, shaded by timepoint index.\n",
    "scatter = ax.scatter(*Xs_pca[:, :3].T, c=ts, cmap='Blues', alpha=0.25, edgecolor='k', rasterized=True)\n",
    "ax.set(xlabel='X', ylabel='Y', zlabel='Z')\n",
    "plt.axis('off')  # note: this also hides the axis labels set above\n",
    "plt.savefig('../../figures/boolode_BF_pca_3d.pdf', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60f9a14f-5276-40e2-9520-0bae2103a66d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../../src')  # location of the project modules (models, utils, ...)\n",
    "import models\n",
    "from torch import optim\n",
    "\n",
    "# Distinct observed timepoints divided by the largest one; these are the time inputs of the score network.\n",
    "ts = torch.tensor(np.sort(np.unique(adata.obs.t_idx)), dtype=torch.float32)\n",
    "ts /= ts.max()\n",
    "dim = Xs.shape[1]\n",
    "T = adata.obs.t_idx.max() + 1  # largest timepoint index + 1; timepoints are iterated as range(T)\n",
    "\n",
    "# Cells of each timepoint as a float32 tensor on the compute device.\n",
    "X = [torch.tensor(Xs[adata.obs.t_idx == i, ...], device=device, dtype=torch.float32) for i in range(T)]\n",
    "\n",
    "# Five noise levels, geometrically spaced from 1 down to exp(-2).\n",
    "sigmas = torch.linspace(0, -2, 5, device=device).exp()\n",
    "s = models.NCScoreFunc(d=dim, hidden_sizes=[64, 64, 64], activation=torch.nn.ReLU, time_dependent=True).to(device)\n",
    "# map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa).\n",
    "s.load_state_dict(torch.load(f'weights/params_NCScoreFunc_default_beta_{beta}_c_{c}_seed_{seed}_final.pt', map_location=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f69d5a8e-061f-4dbc-8dff-e7103dcc01c8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "s.to(device)\n",
    "# One Langevin sampler per timepoint, each started from 1000 standard-normal points.\n",
    "# `_t` is bound as a default argument so every lambda keeps its own timepoint; a plain\n",
    "# closure over `_s` would only see the last value of the loop.\n",
    "samplers = [\n",
    "    models.LangevinSampler(\n",
    "        lambda x, sigma, _t=torch.scalar_tensor(_s).to(device): s(_t, x, sigma),\n",
    "        torch.randn(1_000, dim).to(device),\n",
    "        sigmas=sigmas, dt=1e-3, n_iter=1000,\n",
    "    )\n",
    "    for _s in ts\n",
    "]\n",
    "# Stack the per-timepoint samples along a new leading (timepoint) axis.\n",
    "x_sample = torch.vstack([sampler.sample().cpu()[None, ...] for sampler in samplers])\n",
    "X_all = torch.vstack(X).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19f054ef-076f-443a-965c-3a06673d2158",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Learned score on a 15 x 15 grid in the first two coordinates (all other coordinates = 0),\n",
    "# evaluated at the smallest noise level, for up to 10 timepoints.\n",
    "P = torch.vstack([torch.eye(2), torch.zeros(dim - 2, 2)])  # embeds 2-D grid points into R^dim\n",
    "x_min, x_max = X_all[:, 0].min() - 1, X_all[:, 0].max() + 1\n",
    "y_min, y_max = X_all[:, 1].min() - 1, X_all[:, 1].max() + 1\n",
    "# The evaluation grid does not depend on the timepoint, so build it once.\n",
    "x, y = torch.meshgrid(torch.linspace(x_min, x_max, 15), torch.linspace(y_min, y_max, 15), indexing='ij')\n",
    "_x = torch.vstack([x.flatten(), y.flatten()]).T\n",
    "_x_full = (_x @ P.T).to(torch.float32)\n",
    "sigma_min = sigmas[-1].cpu()\n",
    "plt.figure(figsize=(10, 5))\n",
    "s.cpu()\n",
    "try:\n",
    "    for i in range(min(len(X), 10)):\n",
    "        t = torch.scalar_tensor(ts[i])\n",
    "        plt.subplot(2, 5, i + 1)\n",
    "        with torch.no_grad():\n",
    "            _v = s(t, _x_full, sigma_min) @ P  # project the score back onto the 2-D slice\n",
    "        plt.quiver(_x[:, 0], _x[:, 1], _v[:, 0], _v[:, 1], torch.linalg.norm(_v, dim=1).reshape(x.shape), cmap='RdBu_r', scale=1e3)\n",
    "        X_i = X[i].cpu()\n",
    "        plt.scatter(X_i[:, 0], X_i[:, 1], s=0.1, c='r', alpha=1, zorder=10)\n",
    "        plt.scatter(x_sample[i, :, 0], x_sample[i, :, 1], s=5, c='b', alpha=0.1)\n",
    "        plt.xlim(x_min, x_max); plt.ylim(y_min, y_max)\n",
    "        plt.title(f't = {ts[i]:.2f}')\n",
    "        plt.xlabel('$x_1$'); plt.ylabel('$x_2$')\n",
    "finally:\n",
    "    s.to(device)  # always move the score network back, even if plotting fails\n",
    "plt.suptitle('Learned score function')\n",
    "plt.tight_layout()\n",
    "plt.savefig('../../figures/boolode_BF_score_validation.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94a9735d-baee-455c-b529-1d5728fcbbba",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(7, 3.5))\n",
    "# Left: observed cells coloured by timepoint index. Right: Langevin samples from the learned score,\n",
    "# coloured by the timepoint they were drawn for. The colour vectors get their own names so this\n",
    "# plotting cell does not overwrite the global time grid `ts`.\n",
    "t_obs = torch.tensor(adata.obs.t_idx)\n",
    "ax = fig.add_subplot(1, 2, 1, projection='3d'); ax.view_init(30, -120)\n",
    "ax.scatter(Xs_pca[:, 0], Xs_pca[:, 1], Xs_pca[:, 2], c=t_obs, cmap='Blues', alpha=0.25, edgecolor='k', rasterized=True)\n",
    "plt.axis('off')\n",
    "ax = fig.add_subplot(1, 2, 2, projection='3d'); ax.view_init(30, -120)\n",
    "x_sample_pca = pca_op.transform(x_sample.reshape(-1, dim))\n",
    "t_sample = torch.hstack([torch.full((x_sample.shape[1], ), i) for i in range(x_sample.shape[0])])\n",
    "ax.scatter(x_sample_pca[:, 0], x_sample_pca[:, 1], x_sample_pca[:, 2], alpha=0.1, s=10, c=t_sample, cmap='Blues', edgecolor='k', rasterized=True)\n",
    "plt.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.savefig('../../figures/boolode_BF_score_validation_pca3d.pdf', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31e209ab-85a2-4249-a793-5ecfb9ebafce",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import importlib\n",
    "import math\n",
    "\n",
    "import geomloss\n",
    "import torchdiffeq\n",
    "from torch import nn\n",
    "from torchdiffeq import odeint\n",
    "from tqdm import tqdm\n",
    "\n",
    "import utils\n",
    "import losses\n",
    "import train\n",
    "\n",
    "# Reload the project modules so edits under ../../src are picked up without restarting the kernel.\n",
    "importlib.reload(models)\n",
    "importlib.reload(utils)\n",
    "importlib.reload(losses)\n",
    "importlib.reload(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3945f53-67cd-48c6-b13e-b435f28064db",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Settings shared by the trained models loaded below.\n",
    "D = 0.5 ** 2  # noise parameter; the SDEs below use the noise scale D**0.5 = 0.5\n",
    "# Physical time of each observed timepoint, evenly spaced on [0, t_final]. This replaces the\n",
    "# normalised grid that was used for the score network.\n",
    "ts = torch.linspace(0, data['t_final'], len(X)).to(device)\n",
    "# Cell count at every timepoint relative to the first one.\n",
    "m_ratios = torch.tensor([len(x_t) / len(X[0]) for x_t in X])\n",
    "hidden_sizes = [64] * 3\n",
    "hidden_sizes_mult = [64] * 3\n",
    "odeint_options = dict(method='euler', options=dict(step_size=0.05))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84845f69-4875-42a8-9682-3f5c533cae12",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Each model is rebuilt with the constructor arguments of its checkpoint (load_state_dict is strict,\n",
    "# so parameter names and shapes must match). map_location lets GPU-saved checkpoints load on a\n",
    "# CPU-only machine. The multiplicative models receive the score network `s` (looked up at call\n",
    "# time through the lambda) and the smallest noise level sigmas[-1].\n",
    "\n",
    "# UPFI, multiplicative noise\n",
    "v_upfi_mult = models.MultiplicativeNoiseFlowGrowth(\n",
    "    dim, lambda t, x, sigma: s(t, x, sigma), D, sigmas[-1],\n",
    "    kwargs_u={'output_activation': torch.nn.Softplus, 'time_dependent': False, 'hidden_sizes': hidden_sizes_mult},\n",
    "    kwargs_v={'output_activation': torch.nn.Softplus, 'time_dependent': False, 'hidden_sizes': hidden_sizes_mult},\n",
    "    kwargs_g={'time_dependent': False, 'hidden_sizes': hidden_sizes_mult[:1]},\n",
    ").to(device)\n",
    "v_upfi_mult.load_state_dict(torch.load(f'weights/params_MULT_UPFI_ODEFlowGrowth_default_beta_{beta}_c_{c}_seed_{seed}_final.pt', map_location=device))\n",
    "\n",
    "# PFI, multiplicative noise\n",
    "v_pfi_mult = models.MultiplicativeNoiseFlow(\n",
    "    dim, lambda t, x, sigma: s(t, x, sigma), D, sigmas[-1],\n",
    "    kwargs_u={'output_activation': torch.nn.Softplus, 'time_dependent': False, 'hidden_sizes': hidden_sizes_mult},\n",
    "    kwargs_v={'output_activation': torch.nn.Softplus, 'time_dependent': False, 'hidden_sizes': hidden_sizes_mult},\n",
    ").to(device)\n",
    "v_pfi_mult.load_state_dict(torch.load(f'weights/params_MULT_PFI_VectorField_default_beta_{beta}_c_{c}_seed_{seed}_final.pt', map_location=device))\n",
    "\n",
    "# UPFI, additive noise\n",
    "v_upfi = models.ODEFlowGrowth(d=dim, kwargs_v={'hidden_sizes': hidden_sizes, 'time_dependent': False},\n",
    "                              kwargs_g={'hidden_sizes': hidden_sizes[:1], 'time_dependent': False}).to(device)\n",
    "v_upfi.load_state_dict(torch.load(f'weights/params_UPFI_ODEFlowGrowth_default_beta_{beta}_c_{c}_seed_{seed}_final.pt', map_location=device))\n",
    "\n",
    "# PFI, additive noise\n",
    "v_pfi = models.VectorField(d=dim, hidden_sizes=hidden_sizes, time_dependent=True).to(device)\n",
    "v_pfi.load_state_dict(torch.load(f'weights/params_PFI_VectorField_default_beta_{beta}_c_{c}_seed_{seed}_final.pt', map_location=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9700d2b-2d4b-4231-912d-3ca10c68a42c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Baseline models: ODE (ODEFlowGrowthCoupled) and TIGON (time-dependent velocity and growth nets).\n",
    "# map_location lets GPU-saved checkpoints load on a CPU-only machine.\n",
    "\n",
    "# ODE\n",
    "v_ode = models.ODEFlowGrowthCoupled(d=dim, hidden_sizes=hidden_sizes, time_dependent=True).to(device)\n",
    "v_ode.load_state_dict(torch.load(f'weights/params_ODE_ODEFlowGrowthCoupled_default_beta_{beta}_c_{c}_seed_{seed}_final.pt', map_location=device))\n",
    "# TIGON\n",
    "v_tigon = models.ODEFlowGrowth(d=dim, kwargs_v={'hidden_sizes': hidden_sizes, 'time_dependent': True},\n",
    "                               kwargs_g={'hidden_sizes': hidden_sizes, 'time_dependent': True}).to(device)\n",
    "v_tigon.load_state_dict(torch.load(f'weights/params_TIGON_ODEFlowGrowth_default_beta_{beta}_c_{c}_seed_{seed}_final.pt', map_location=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cd499da-6c69-48f4-9b3b-ad1637bebcb3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sample 1024 starting cells (with replacement, with noise added) via the project helper and use\n",
    "# element [0] of the result as the initial state. Column 0 of that state is dropped below to get\n",
    "# plain cell states. NOTE(review): it appears to be a mass coordinate (the UPFI SDE puts zero noise\n",
    "# on it) - confirm in utils.sample_batch_upfi.\n",
    "_batches = utils.sample_batch_upfi(X, m_ratios.to(device), batch_size=1024, replacement=True, add_noise=True)\n",
    "x0_mass = _batches[0]\n",
    "x0 = x0_mass[..., 1:]"
   ]
  },
   "cell_type": "code",
   "execution_count": null,
   "id": "d6e18348-c0c3-4c12-907e-d72d6fedadf7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torchsde\n",
    "import seaborn as sb\n",
    "Xs_pca = pca_op.transform(Xs)\n",
    "\n",
    "v_upfi_mult.to(device); v_pfi_mult.to(device); v_pfi.to(device); v_upfi.to(device);\n",
    "sde_pfi = models.SDE(lambda t, x: v_pfi(t, x), sigma = D**0.5)\n",
    "sde_upfi = models.SDE(lambda t, x: v_upfi(t, x), sigma = torch.cat([torch.tensor([0, ]), torch.full((dim, ), D**0.5)]).to(device))\n",
    "sde_pfi_mult = models.MultiplicativeNoiseSDE(v_pfi_mult.u, v_pfi_mult.v, sigma = D**0.5)\n",
    "sde_upfi_mult = models.MultiplicativeNoiseSDE(v_upfi_mult.u, v_upfi_mult.v, sigma = D**0.5)\n",
    "\n",
    "with torch.no_grad():\n",
    "    xs_t_pfi = torchsde.sdeint(sde_pfi, x0, ts, method = \"euler\").cpu()\n",
    "    xs_t_upfi = torchsde.sdeint(sde_upfi, x0_mass, ts, method = \"euler\").cpu()[..., 1:]\n",
    "    xs_t_pfi_mult = torchsde.sdeint(sde_pfi_mult, x0, ts, method = \"euler\").cpu()\n",
    "    xs_t_upfi_mult = torchsde.sdeint(sde_upfi_mult, x0, ts, method = \"euler\").cpu()\n",
    "xs_t_pfi_ = xs_t_pfi.reshape(-1, xs_t_pfi.shape[-1]).cpu()\n",
    "xs_t_upfi_ = xs_t_upfi.reshape(-1, xs_t_upfi.shape[-1]).cpu()\n",
    "xs_t_pfi_mult_ = xs_t_pfi_mult.reshape(-1, xs_t_pfi_mult.shape[-1]).cpu()\n",
    "xs_t_upfi_mult_ = xs_t_upfi_mult.reshape(-1, xs_t_upfi_mult.shape[-1]).cpu()\n",
    "\n",
    "with torch.no_grad():\n",
    "    xs_t_pfi_pca_ = pca_op.transform(xs_t_pfi.reshape(-1, xs_t_pfi.shape[-1]).cpu())\n",
    "    xs_t_upfi_pca_ = pca_op.transform(xs_t_upfi.reshape(-1, xs_t_upfi.shape[-1]).cpu())\n",
    "    xs_t_pfi_mult_pca_ = pca_op.transform(xs_t_pfi_mult.reshape(-1, xs_t_pfi_mult.shape[-1]).cpu())\n",
    "    xs_t_upfi_mult_pca_ = pca_op.transform(xs_t_upfi_mult.reshape(-1, xs_t_upfi_mult.shape[-1]).cpu())\n",
    "xs_t_pfi_pca = xs_t_pfi_pca_.reshape(xs_t_pfi.shape)\n",
    "xs_t_upfi_pca = xs_t_upfi_pca_.reshape(xs_t_upfi.shape)\n",
    "xs_t_pfi_mult_pca = xs_t_pfi_mult_pca_.reshape(xs_t_pfi_mult.shape)\n",
    "xs_t_upfi_mult_pca = xs_t_upfi_mult_pca_.reshape(xs_t_upfi_mult.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2af11ff9-e8ba-466e-a283-d4e260ddda46",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "k = 0  # index of the first of the three principal components shown\n",
    "fig = plt.figure(figsize=(5, 5))\n",
    "ax = fig.add_subplot(1, 1, 1, projection='3d')\n",
    "ax.view_init(30, -120)\n",
    "# Observed cells coloured by the value of feature 4 of data['x'].\n",
    "ax.scatter(*Xs_pca[:, k:k + 3].T, c=data['x'][:, 4], cmap='viridis', alpha=0.1, edgecolor='k')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "009280f6-2966-4e75-8282-66596e97352b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Project the ground-truth trajectories onto the reference PCA while keeping their original\n",
    "# array layout. The feature width is read from the array instead of being hard-coded.\n",
    "_n_features = data['x_paths'].shape[-1]\n",
    "X_paths_pca = pca_op.transform(data['x_paths'].reshape(-1, _n_features)).reshape(data['x_paths'].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88c8b6ca-ce91-491b-9bfa-de2df48a3f47",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "k = 0\n",
    "fig = plt.figure(figsize=(15, 3.5))\n",
    "# One 3-D PCA panel per source: points coloured by time plus 25 trajectories in red.\n",
    "# Simulated trajectories are stored (time, path, component), so they are transposed to\n",
    "# (path, time, component) to match X_paths_pca. Panels are created in the order 1, 2, 4, 3, 5.\n",
    "_panels = [\n",
    "    (1, 'Data', Xs_pca, adata.obs.t_idx, X_paths_pca),\n",
    "    (2, 'UPFI Additive', xs_t_upfi_pca_, ts.repeat_interleave(xs_t_upfi.shape[1]).cpu(), xs_t_upfi_pca.transpose(1, 0, 2)),\n",
    "    (4, 'PFI Additive', xs_t_pfi_pca_, ts.repeat_interleave(xs_t_pfi.shape[1]).cpu(), xs_t_pfi_pca.transpose(1, 0, 2)),\n",
    "    (3, 'UPFI Mult', xs_t_upfi_mult_pca_, ts.repeat_interleave(xs_t_upfi_mult.shape[1]).cpu(), xs_t_upfi_mult_pca.transpose(1, 0, 2)),\n",
    "    (5, 'PFI Mult', xs_t_pfi_mult_pca_, ts.repeat_interleave(xs_t_pfi_mult.shape[1]).cpu(), xs_t_pfi_mult_pca.transpose(1, 0, 2)),\n",
    "]\n",
    "for _pos, _title, _pts, _colors, _paths in _panels:\n",
    "    ax = fig.add_subplot(1, 5, _pos, projection='3d')\n",
    "    ax.view_init(30, -120)\n",
    "    ax.scatter(_pts[:, k], _pts[:, k + 1], _pts[:, k + 2], c=_colors, cmap='Blues', alpha=0.1, edgecolor='k', rasterized=True)\n",
    "    for _path in _paths[:25]:\n",
    "        ax.plot(_path[:, k], _path[:, k + 1], _path[:, k + 2], c='r', alpha=0.3, zorder=100)\n",
    "    ax.set_xlim(-5, 5); ax.set_ylim(-5, 5); ax.set_zlim(-2, 6)\n",
    "    ax.set_title(_title)\n",
    "    ax.set_axis_off()\n",
    "plt.tight_layout()\n",
    "plt.savefig('../../figures/boolode_BF_sample_paths_pca3d_coarse.pdf', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e3a3499-6ac4-4f09-a6cc-43725db75f41",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import evals\n",
    "importlib.reload(evals)\n",
    "# Energy distance between simulated and ground-truth trajectories, per model.\n",
    "# Simulated paths are (time, path, feature); permute to (path, time, feature) before comparing.\n",
    "_simulated = {\n",
    "    'UPFI (additive)': xs_t_upfi,\n",
    "    'UPFI (mult.)': xs_t_upfi_mult,\n",
    "    'PFI (additive)': xs_t_pfi,\n",
    "    'PFI (mult.)': xs_t_pfi_mult,\n",
    "}\n",
    "_energy = {name: evals.energy_distance_paths(paths.permute((1, 0, 2)).numpy(), data['x_paths']) for name, paths in _simulated.items()}\n",
    "sb.barplot(_energy, palette='tab10')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93b8fade-b2a1-4df8-a141-8a37709f069f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fate probabilities (column 0) estimated by each model. The file is a pickled .pkl like the\n",
    "# simulation data, so weights_only=False is passed explicitly (torch>=2.6 defaults to True and\n",
    "# would refuse non-tensor objects). Unpickling can execute code: only load trusted files.\n",
    "probs = torch.load(f'evals/fate_probs_seed_{seed}.pkl', weights_only=False)\n",
    "k = 0\n",
    "fig = plt.figure(figsize=(15, 3.5))\n",
    "# Panels are created in the order 1, 2, 4, 3, 5.\n",
    "_panels = [\n",
    "    (1, 'Data', data['probs'][:, 0]),\n",
    "    (2, 'UPFI Additive', probs['probs_upfi'][:, 0]),\n",
    "    (4, 'PFI Additive', probs['probs_pfi'][:, 0]),\n",
    "    (3, 'UPFI Mult.', probs['probs_upfi_mult'][:, 0]),\n",
    "    (5, 'PFI Mult.', probs['probs_pfi_mult'][:, 0]),\n",
    "]\n",
    "for _pos, _title, _fate in _panels:\n",
    "    ax = fig.add_subplot(1, 5, _pos, projection='3d')\n",
    "    ax.view_init(30, -120)\n",
    "    ax.scatter(Xs_pca[:, k], Xs_pca[:, k + 1], Xs_pca[:, k + 2], c=_fate, cmap='bwr', alpha=0.5, s=1, rasterized=True)\n",
    "    ax.set_axis_off()\n",
    "    ax.set_title(_title)\n",
    "plt.tight_layout()\n",
    "plt.savefig('../../figures/boolode_BF_fates_pca3d.pdf', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0590a60e-3a66-4d27-b510-2a7786faac7f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(3, 3))\n",
    "# Pearson correlation between the true fate probability (column 0) and each model's estimate.\n",
    "_fate_true = data['probs'][:, 0]\n",
    "_fate_keys = {'UPFI': 'probs_upfi', 'PFI': 'probs_pfi', 'UPFI Mult': 'probs_upfi_mult', 'PFI Mult': 'probs_pfi_mult'}\n",
    "_fate_corr = {label: np.corrcoef(_fate_true, probs[key][:, 0])[1, 0] for label, key in _fate_keys.items()}\n",
    "sb.barplot(_fate_corr, palette='tab10')\n",
    "plt.title('Fate probability correlation')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01cc13cd-71ee-4614-81aa-7255eecfc9f0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Re-simulate every model on a fine grid of N_FINE times (matching the 25-timepoint reference\n",
    "# trajectories plotted next) and again on the coarse observed grid `ts`. The fine grid is placed on\n",
    "# `device`, like `ts`, so time-dependent drifts get their time argument on the same device as x0.\n",
    "N_FINE = 25\n",
    "_ts = torch.linspace(0, data['t_final'], N_FINE).to(device)\n",
    "with torch.no_grad():\n",
    "    # fine time grid (replaces the trajectories from the earlier simulation cell)\n",
    "    xs_t_pfi = torchsde.sdeint(sde_pfi, x0, _ts, method='euler').cpu()\n",
    "    xs_t_upfi = torchsde.sdeint(sde_upfi, x0_mass, _ts, method='euler').cpu()[..., 1:]\n",
    "    xs_t_pfi_mult = torchsde.sdeint(sde_pfi_mult, x0, _ts, method='euler').cpu()\n",
    "    xs_t_upfi_mult = torchsde.sdeint(sde_upfi_mult, x0, _ts, method='euler').cpu()\n",
    "    # coarse time grid\n",
    "    xs_t_pfi_coarse = torchsde.sdeint(sde_pfi, x0, ts, method='euler').cpu()\n",
    "    xs_t_upfi_coarse = torchsde.sdeint(sde_upfi, x0_mass, ts, method='euler').cpu()[..., 1:]\n",
    "    xs_t_pfi_mult_coarse = torchsde.sdeint(sde_pfi_mult, x0, ts, method='euler').cpu()\n",
    "    xs_t_upfi_mult_coarse = torchsde.sdeint(sde_upfi_mult, x0, ts, method='euler').cpu()\n",
    "\n",
    "# Flatten (time, cell) into rows, project onto the reference PCA, and restore the\n",
    "# (time, cell, component) layout. Each flattened array is reused rather than rebuilt.\n",
    "# fine time grid\n",
    "xs_t_pfi_ = xs_t_pfi.reshape(-1, xs_t_pfi.shape[-1])\n",
    "xs_t_upfi_ = xs_t_upfi.reshape(-1, xs_t_upfi.shape[-1])\n",
    "xs_t_pfi_mult_ = xs_t_pfi_mult.reshape(-1, xs_t_pfi_mult.shape[-1])\n",
    "xs_t_upfi_mult_ = xs_t_upfi_mult.reshape(-1, xs_t_upfi_mult.shape[-1])\n",
    "xs_t_pfi_pca_ = pca_op.transform(xs_t_pfi_)\n",
    "xs_t_upfi_pca_ = pca_op.transform(xs_t_upfi_)\n",
    "xs_t_pfi_mult_pca_ = pca_op.transform(xs_t_pfi_mult_)\n",
    "xs_t_upfi_mult_pca_ = pca_op.transform(xs_t_upfi_mult_)\n",
    "xs_t_pfi_pca = xs_t_pfi_pca_.reshape(xs_t_pfi.shape)\n",
    "xs_t_upfi_pca = xs_t_upfi_pca_.reshape(xs_t_upfi.shape)\n",
    "xs_t_pfi_mult_pca = xs_t_pfi_mult_pca_.reshape(xs_t_pfi_mult.shape)\n",
    "xs_t_upfi_mult_pca = xs_t_upfi_mult_pca_.reshape(xs_t_upfi_mult.shape)\n",
    "# coarse time grid\n",
    "xs_t_pfi_coarse_ = xs_t_pfi_coarse.reshape(-1, xs_t_pfi_coarse.shape[-1])\n",
    "xs_t_upfi_coarse_ = xs_t_upfi_coarse.reshape(-1, xs_t_upfi_coarse.shape[-1])\n",
    "xs_t_pfi_mult_coarse_ = xs_t_pfi_mult_coarse.reshape(-1, xs_t_pfi_mult_coarse.shape[-1])\n",
    "xs_t_upfi_mult_coarse_ = xs_t_upfi_mult_coarse.reshape(-1, xs_t_upfi_mult_coarse.shape[-1])\n",
    "xs_t_pfi_pca_coarse_ = pca_op.transform(xs_t_pfi_coarse_)\n",
    "xs_t_upfi_pca_coarse_ = pca_op.transform(xs_t_upfi_coarse_)\n",
    "xs_t_pfi_mult_pca_coarse_ = pca_op.transform(xs_t_pfi_mult_coarse_)\n",
    "xs_t_upfi_mult_pca_coarse_ = pca_op.transform(xs_t_upfi_mult_coarse_)\n",
    "xs_t_pfi_pca_coarse = xs_t_pfi_pca_coarse_.reshape(xs_t_pfi_coarse.shape)\n",
    "xs_t_upfi_pca_coarse = xs_t_upfi_pca_coarse_.reshape(xs_t_upfi_coarse.shape)\n",
    "xs_t_pfi_mult_pca_coarse = xs_t_pfi_mult_pca_coarse_.reshape(xs_t_pfi_mult_coarse.shape)\n",
    "xs_t_upfi_mult_pca_coarse = xs_t_upfi_mult_pca_coarse_.reshape(xs_t_upfi_mult_coarse.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58c07c81-d98a-466a-b5a3-323e58da027c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "k = 0\n",
    "fig = plt.figure(figsize=(15, 3.5))\n",
    "# Fine-grid reference trajectories from the beta = 0 simulation with 25 timepoints. The file name\n",
    "# uses the configured `c` instead of a hard-coded 0.5, and weights_only=False is passed as for the\n",
    "# other pickled simulation files (torch>=2.6 defaults to True). Only load trusted files.\n",
    "_data = torch.load(f'sim_BF_beta_0_N_500_T_25_c_{c}.pkl', weights_only=False)\n",
    "_xs_t_pca = pca_op.transform(_data['x_paths'].reshape(-1, _data['x_paths'].shape[-1])).reshape(_data['x_paths'].shape)\n",
    "# Points: coarse-grid samples coloured by time. Red lines: 25 fine-grid trajectories. Simulated\n",
    "# trajectories are (time, path, component), so they are transposed to (path, time, component).\n",
    "# Panels are created in the order 1, 2, 4, 3, 5.\n",
    "_panels = [\n",
    "    (1, 'Data', Xs_pca, adata.obs.t_idx, _xs_t_pca),\n",
    "    (2, 'UPFI Additive', xs_t_upfi_pca_coarse_, ts.repeat_interleave(xs_t_upfi_coarse.shape[1]).cpu(), xs_t_upfi_pca.transpose(1, 0, 2)),\n",
    "    (4, 'PFI Additive', xs_t_pfi_pca_coarse_, ts.repeat_interleave(xs_t_pfi_coarse.shape[1]).cpu(), xs_t_pfi_pca.transpose(1, 0, 2)),\n",
    "    (3, 'UPFI Mult', xs_t_upfi_mult_pca_coarse_, ts.repeat_interleave(xs_t_upfi_mult_coarse.shape[1]).cpu(), xs_t_upfi_mult_pca.transpose(1, 0, 2)),\n",
    "    (5, 'PFI Mult', xs_t_pfi_mult_pca_coarse_, ts.repeat_interleave(xs_t_pfi_mult_coarse.shape[1]).cpu(), xs_t_pfi_mult_pca.transpose(1, 0, 2)),\n",
    "]\n",
    "for _pos, _title, _pts, _colors, _paths in _panels:\n",
    "    ax = fig.add_subplot(1, 5, _pos, projection='3d')\n",
    "    ax.view_init(30, -120)\n",
    "    ax.scatter(_pts[:, k], _pts[:, k + 1], _pts[:, k + 2], c=_colors, cmap='Blues', alpha=0.1, edgecolor='k', rasterized=True)\n",
    "    for _path in _paths[:25]:\n",
    "        ax.plot(_path[:, k], _path[:, k + 1], _path[:, k + 2], c='r', alpha=0.3, zorder=250)\n",
    "    ax.set_xlim(-5, 5); ax.set_ylim(-5, 5); ax.set_zlim(-2, 6)\n",
    "    ax.set_axis_off()\n",
    "    ax.set_title(_title)\n",
    "plt.tight_layout()\n",
    "plt.savefig('../../figures/boolode_BF_sample_paths_pca3d_fine.pdf', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bb2e3ef-c482-4f1c-a0e1-b66c7f8cdadc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "f_true = torch.tensor(data['f'], dtype=torch.float32)\n",
    "g_true = torch.tensor(data['g'], dtype=torch.float32)\n",
    "_x = torch.tensor(data['x'], dtype=torch.float32).to(device)\n",
    "_t_idx = np.asarray(data['t_idx'])\n",
    "\n",
    "def _eval_per_timepoint(fn):\n",
    "    \"\"\"Evaluate fn(ts[i], x) on the cells of each timepoint i and return the CPU result in the\n",
    "    row order of data['x'], so it lines up row-for-row with f_true, the u/v estimates and Xs_pca\n",
    "    even when data['x'] is not sorted by timepoint.\"\"\"\n",
    "    parts, rows = [], []\n",
    "    for i in range(T):\n",
    "        idx = np.flatnonzero(_t_idx == i)\n",
    "        parts.append(fn(ts[i], _x[torch.as_tensor(idx, device=device)]).cpu())\n",
    "        rows.append(idx)\n",
    "    stacked = torch.vstack(parts)\n",
    "    out = torch.empty_like(stacked)\n",
    "    out[torch.as_tensor(np.concatenate(rows))] = stacked  # scatter rows back to data order\n",
    "    return out\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Multiplicative-noise models: time-independent u and v nets (their difference is compared to f_true).\n",
    "    u_est_upfi_mult = v_upfi_mult.u.net(_x).cpu()\n",
    "    v_est_upfi_mult = v_upfi_mult.v.net(_x).cpu()\n",
    "    u_est_pfi_mult = v_pfi_mult.u.net(_x).cpu()\n",
    "    v_est_pfi_mult = v_pfi_mult.v.net(_x).cpu()\n",
    "    # Additive-noise models are evaluated at each timepoint's physical time ts[i].\n",
    "    vf_pfi = _eval_per_timepoint(v_pfi)\n",
    "    vf_upfi = _eval_per_timepoint(v_upfi.v_net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45d2a802-64e9-45e4-b07a-7ea857fd5127",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import evals, utils\n",
    "importlib.reload(utils)\n",
    "# Cells whose feature 4 exceeds 1.0; for each model report sqrt(mean cosine distance) to f_true on them.\n",
    "_idx = (data['x'][:, 4] > 1.0)\n",
    "_drift_estimates = (\n",
    "    u_est_upfi_mult - v_est_upfi_mult,  # UPFI (mult.)\n",
    "    u_est_pfi_mult - v_est_pfi_mult,    # PFI (mult.)\n",
    "    vf_pfi,                             # PFI (additive)\n",
    "    vf_upfi,                            # UPFI (additive)\n",
    ")\n",
    "tuple(utils.cos_dist(est, f_true)[_idx].mean() ** 0.5 for est in _drift_estimates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01d9becf-bcf8-42f4-97aa-2fe427513e1c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(5, 5))\n",
    "ax = fig.add_subplot(1, 1, 1, projection='3d')\n",
    "ax.view_init(30, -120)\n",
    "# Highlight the cells selected by the mask `_idx`.\n",
    "ax.scatter(*Xs_pca[:, k:k + 3].T, c=_idx, cmap='viridis', alpha=0.5, edgecolor='k')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "137d28a0-9883-458d-b7be-0a2682f47517",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# utils.l2_dist between each model's drift estimate and the true drift f_true.\n",
    "_l2 = {}\n",
    "_l2['PFI'] = utils.l2_dist(vf_pfi, f_true)\n",
    "_l2['PFI (Mult.)'] = utils.l2_dist(u_est_pfi_mult - v_est_pfi_mult, f_true)\n",
    "_l2['UPFI'] = utils.l2_dist(vf_upfi, f_true)\n",
    "_l2['UPFI (Mult.)'] = utils.l2_dist(u_est_upfi_mult - v_est_upfi_mult, f_true)\n",
    "sb.barplot(_l2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd3eccaa-0b02-4e11-92f6-a1d261aa8cdc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(10, 5))\n",
    "# Per-cell cosine distance between the learned and true drift for the additive-noise models,\n",
    "# on a shared 0-0.5 colour scale. Each panel is titled so the two can be told apart.\n",
    "for _pos, (_title, _vf) in enumerate([('PFI (add.)', vf_pfi), ('UPFI (add.)', vf_upfi)], start=1):\n",
    "    ax = fig.add_subplot(1, 2, _pos, projection='3d')\n",
    "    ax.view_init(30, -120)\n",
    "    z = utils.cos_dist(_vf, f_true)\n",
    "    p = ax.scatter(Xs_pca[:, k], Xs_pca[:, k + 1], Xs_pca[:, k + 2], c=z, cmap='viridis', alpha=0.5, edgecolor='k', vmin=0, vmax=0.5)\n",
    "    ax.set_title(_title)\n",
    "    fig.colorbar(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e496442-3f55-453c-8a8e-f6482354881b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import plotting\n",
    "plt.figure(figsize=(10, 2.5))\n",
    "# Each panel: cells in PC1/PC2 coloured by fate probability (column 0), overlaid with streamlines of\n",
    "# a drift field projected onto the same PCs. Drifts are multiplied by components_ only (no mean\n",
    "# subtraction) because they are directions, not positions. Panels are created in the order 1, 2, 4, 3, 5.\n",
    "_panels = [\n",
    "    (1, 'True', data['probs'][:, 0], f_true),\n",
    "    (2, 'UPFI (add.)', probs['probs_upfi'][:, 0], vf_upfi),\n",
    "    (4, 'PFI (add.)', probs['probs_pfi'][:, 0], vf_pfi),\n",
    "    (3, 'UPFI (mult.)', probs['probs_upfi_mult'][:, 0], u_est_upfi_mult - v_est_upfi_mult),\n",
    "    (5, 'PFI (mult.)', probs['probs_pfi_mult'][:, 0], u_est_pfi_mult - v_est_pfi_mult),\n",
    "]\n",
    "for _pos, _title, _fate, _drift in _panels:\n",
    "    plt.subplot(1, 5, _pos)\n",
    "    plt.scatter(Xs_pca[:, 0], Xs_pca[:, 1], alpha=0.05, s=25, c=_fate, cmap='bwr', vmin=-0.15, vmax=1.15, rasterized=True)\n",
    "    plotting.plot_stream_vectorfield(Xs_pca, _drift @ pca_op.components_.T, color='k')\n",
    "    plt.title(_title)\n",
    "    plt.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.savefig('../../figures/boolode_BF_vf_fates_pca2d.pdf', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0a627aa-0b02-4e11-92f6-a1d261aa8cdc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Outputs of v_upfi.v_net and v_pfi at every (t, x) in zip(ts, X), next to f_true on the same colour scale.\n",
    "plt.figure(figsize = (10, 3))\n",
    "for slot, (title, field) in enumerate([(\"UPFI\", v_upfi.v_net), (\"PFI\", v_pfi)], start = 1):\n",
    "    plt.subplot(1, 3, slot)\n",
    "    with torch.no_grad():\n",
    "        _v = torch.vstack([field(t, x).cpu() for (t, x) in zip(ts, X)])\n",
    "    sb.heatmap(_v, vmax = 10, vmin = -10, cmap = \"RdBu_r\")\n",
    "    plt.title(title)\n",
    "plt.subplot(1, 3, 3)\n",
    "sb.heatmap(f_true, vmin = -10, vmax = 10, cmap = \"RdBu_r\")\n",
    "plt.title(\"$u_{true} - v_{true}$\")\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dcb44ac-2acd-4d01-acf0-00b81d72a718",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# UPFI (mult.) estimate u - v next to u_true - v_true (f_true), on a shared colour scale.\n",
    "plt.figure(figsize = (8, 3))\n",
    "plt.subplot(1, 2, 1)\n",
    "sb.heatmap(u_est_upfi_mult - v_est_upfi_mult, vmax = 10, vmin = -10, cmap = \"RdBu_r\")\n",
    "plt.title(\"UPFI (mult.): $u_{infer} - v_{infer}$\")\n",
    "plt.subplot(1, 2, 2)\n",
    "sb.heatmap(f_true, vmin = -10, vmax = 10, cmap = \"RdBu_r\")\n",
    "plt.title(\"$u_{true} - v_{true}$\")\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db3bcbe4-e9b2-4ebe-9502-024ccef9cf53",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# PFI (mult.) estimate u - v next to u_true - v_true (f_true), on a shared colour scale.\n",
    "plt.figure(figsize = (8, 3))\n",
    "plt.subplot(1, 2, 1)\n",
    "sb.heatmap(u_est_pfi_mult - v_est_pfi_mult, vmax = 10, vmin = -10, cmap = \"RdBu_r\")\n",
    "plt.title(\"PFI (mult.): $u_{infer} - v_{infer}$\")\n",
    "plt.subplot(1, 2, 2)\n",
    "sb.heatmap(f_true, vmin = -10, vmax = 10, cmap = \"RdBu_r\")\n",
    "plt.title(\"$u_{true} - v_{true}$\")\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8507118b-db6d-448c-b4af-9a4b9a6cc2cb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import importlib  # `importlib.reload` below needs the module in scope on a fresh kernel\n",
    "import math\n",
    "# Pick up edits to the `models` source without restarting the kernel. Instances created\n",
    "# before this reload keep their old class definitions, so re-create them afterwards.\n",
    "importlib.reload(models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65140862-e87d-4759-a397-806fcbe62aca",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Display the repr of the g_net sub-module of v_upfi.\n",
    "v_upfi.g_net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0eb56c85-679b-4008-b6e2-0c0c32576046",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# NOTE(review): g_net is called with an explicit `None` time argument. Previously IPython's `_`\n",
    "# (the last displayed output) was passed as a throw-away placeholder, which presumably means\n",
    "# g_net ignores its time input -- confirm against `models`.\n",
    "with torch.no_grad():\n",
    "    x_all = torch.tensor(data['x'], dtype = torch.float32).to(device)\n",
    "    g_est = v_upfi.g_net(None, x_all).cpu().flatten()\n",
    "\n",
    "# Left: data['beta'] per cell; right: g_est per cell; both on PCs 0-2.\n",
    "fig = plt.figure(figsize = (10, 4))\n",
    "ax = fig.add_subplot(1, 2, 1, projection = '3d'); ax.view_init(30, -120)\n",
    "p = ax.scatter(Xs_pca[:, 0], Xs_pca[:, 1], Xs_pca[:, 2], c = data['beta'], cmap = 'RdBu_r', vmin = -1, vmax = 1, alpha = 0.25, edgecolor = 'k')\n",
    "ax.set_title(\"beta (data)\")\n",
    "fig.colorbar(p)\n",
    "ax = fig.add_subplot(1, 2, 2, projection = '3d'); ax.view_init(30, -120)\n",
    "p = ax.scatter(Xs_pca[:, 0], Xs_pca[:, 1], Xs_pca[:, 2], c = g_est, cmap = 'RdBu_r', vmin = -5, vmax = 5, alpha = 0.25, edgecolor = 'k')\n",
    "ax.set_title(\"g_net (UPFI)\")\n",
    "fig.colorbar(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c193ef06-1901-498a-9382-5e615a5dd856",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Per-cell scatter of data['beta'] against g_est (the v_upfi.g_net output).\n",
    "plt.scatter(data['beta'], g_est, alpha = 0.1)\n",
    "plt.ylim(-10, 10.)\n",
    "plt.xlabel(\"beta (data)\")\n",
    "plt.ylabel(\"g_net estimate (UPFI)\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0aeb2e29-34c0-413a-a444-865fe2af4e30",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import scipy as sp\n",
    "import scipy.stats  # older SciPy releases do not load submodules on a bare `import scipy`\n",
    "# Pearson correlation between data['beta'] and g_est.\n",
    "sp.stats.pearsonr(data['beta'], g_est)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b50263ad-62c7-4507-89d8-3fca9490b2b3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Mean of v_upfi.v_net.jacobian over all (t, x) in zip(ts, X), shown transposed with the y-axis flipped.\n",
    "with torch.no_grad():\n",
    "    jacs = torch.cat([v_upfi.v_net.jacobian(t, x).cpu() for t, x in zip(ts, X)])\n",
    "plt.imshow(jacs.mean(0).T, cmap = \"RdBu_r\", vmin = -10, vmax = 10)\n",
    "plt.gca().invert_yaxis()\n",
    "plt.colorbar()\n",
    "plt.title(\"UPFI: mean Jacobian of v_net\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6fd6761-ce7f-4be2-bfda-02056942bf7b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Mean of v_pfi.jacobian over all (t, x) in zip(ts, X), shown transposed with the y-axis flipped.\n",
    "with torch.no_grad():\n",
    "    jacs = torch.cat([v_pfi.jacobian(t, x).cpu() for t, x in zip(ts, X)])\n",
    "plt.imshow(jacs.mean(0).T, cmap = \"RdBu_r\", vmin = -10, vmax = 10)\n",
    "plt.gca().invert_yaxis()\n",
    "plt.colorbar()\n",
    "plt.title(\"PFI: mean Jacobian\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c2e9b14-5d3a-4f6e-8a1b-0e4d2c9f6a35",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "import glob\n",
    "from toolz import interleave\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def format_mean_std(df_mean, df_std, best = np.argmin, digits = 2):\n",
    "    \"\"\"Join same-shape mean/std tables into LaTeX \"<mean> +/- <std>\" strings; bold each row's best entry.\n",
    "\n",
    "    Entries are paired by position, so the two tables may carry different labels\n",
    "    (e.g. the 'mean' / 'std' rows from DataFrame.agg, or ('col', 'mean') / ('col', 'std')\n",
    "    columns from groupby().agg). The result uses df_mean's index and columns.\n",
    "\n",
    "    best   -- np.argmin or np.argmax, applied along axis 1 of df_mean.values to choose\n",
    "              the entry wrapped in a LaTeX textbf.\n",
    "    digits -- decimals for numeric entries; non-numeric entries are passed through.\n",
    "    \"\"\"\n",
    "    def fmt(x):\n",
    "        return f\"{x:.{digits}f}\" if isinstance(x, (int, float)) else x\n",
    "\n",
    "    # Series.map per column formats elementwise without DataFrame.applymap,\n",
    "    # which pandas 2.1 deprecated (renamed to DataFrame.map).\n",
    "    mean_str = df_mean.apply(lambda col: col.map(fmt)).to_numpy()\n",
    "    std_str = df_std.apply(lambda col: col.map(fmt)).to_numpy()\n",
    "    cells = [[m + \" $\\\\pm$ \" + s for m, s in zip(m_row, s_row)]\n",
    "             for m_row, s_row in zip(mean_str, std_str)]\n",
    "    out = pd.DataFrame(cells, index = df_mean.index, columns = df_mean.columns)\n",
    "    for i, j in enumerate(best(df_mean.values, 1)):\n",
    "        out.iloc[i, j] = \"\\\\textbf{\" + out.iloc[i, j] + \"}\"\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0256db03-a5b8-4e7b-a70d-d6ac4394c9f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "files = glob.glob(\"evals/df_energy_distance_paths*.csv\")\n",
    "assert files, \"no evals/df_energy_distance_paths*.csv files found\"\n",
    "df = pd.concat([pd.read_csv(f, index_col = 0) for f in files])\n",
    "_df_mean = df.agg(['mean', ])\n",
    "_df_std = df.agg(['std', ])\n",
    "# The lowest mean is bolded; the row index is reset to positional (0, 1, ...) for the table.\n",
    "_df_str = format_mean_std(_df_mean, _df_std, best = np.argmin).reset_index(drop = True)\n",
    "_df_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31ad3499-b643-419c-8c23-3dff2a7e99c2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(_df_str.to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdfa4f58-b4a0-44d7-9bd1-f18fdc4ef7e0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "files = glob.glob(\"evals/df_vf_l2_dist*.csv\")\n",
    "assert files, \"no evals/df_vf_l2_dist*.csv files found\"\n",
    "# Square root of the stored values (presumably squared L2 distances -- confirm), then reorder columns.\n",
    "df = pd.concat([pd.read_csv(f, index_col = 0) for f in files])**0.5\n",
    "df = df.iloc[:, [2, 3, 0, 1]]\n",
    "_df_mean = df.agg(['mean', ])\n",
    "_df_std = df.agg(['std', ])\n",
    "_df_str = format_mean_std(_df_mean, _df_std, best = np.argmin).reset_index(drop = True)\n",
    "print(_df_str.to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a040374-25b5-495b-865f-b0c9773552b5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "files = glob.glob(\"evals/df_fate_pearsonr*.csv\")\n",
    "assert files, \"no evals/df_fate_pearsonr*.csv files found\"\n",
    "df = pd.concat([pd.read_csv(f, index_col = 0) for f in files])\n",
    "_df_mean = df.groupby('what').agg(['mean', ])\n",
    "_df_std = df.groupby('what').agg(['std', ])\n",
    "# Higher correlation is better here, so the row maximum is bolded.\n",
    "_df_str = format_mean_std(_df_mean, _df_std, best = np.argmax)\n",
    "print(_df_str.reset_index().to_latex())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "upfi",
   "language": "python",
   "name": "upfi"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
