{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dccfb8f7-b61c-4f2f-ab67-97a17dc1157b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import autograd\n",
    "import autograd.numpy as np\n",
    "import torchdiffeq\n",
    "from torchdiffeq import odeint\n",
    "import geomloss\n",
    "from tqdm import tqdm\n",
    "import importlib\n",
    "import math\n",
    "import random\n",
    "import torchsde\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "torch.set_default_dtype(torch.float32)\n",
    "# Fix the global RNGs so the stochastic steps (the torch.randn draws handed to the Langevin\n",
    "# samplers, the SDE simulations) are reproducible under Restart & Run All. torchsde's Brownian\n",
    "# motion appears to take its default entropy from Python's `random` module - confirm for the\n",
    "# installed version.\n",
    "RNG_SEED = 0\n",
    "random.seed(RNG_SEED)\n",
    "torch.manual_seed(RNG_SEED);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac6244ac-d9bf-446a-ad55-7c5f19b7f23b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sklearn as sk\n",
    "import sklearn.decomposition\n",
    "# Imported explicitly: `sk.metrics` is used for the PR/ROC curves below, and importing the\n",
    "# top-level package does not guarantee that this submodule is loaded.\n",
    "import sklearn.metrics\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "df = pd.read_csv(\"../../data/HSC/ExpressionData.csv\", index_col = 0)\n",
    "genes = df.index\n",
    "# Display order used for every gene-by-gene matrix in this notebook\n",
    "genes_reord = ['Gata1', 'Gata2', 'Fog1', 'Eklf', 'Fli1', 'Scl', 'Cebpa', 'Pu1', 'cJun', 'EgrNab', 'Gfi1']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6818a73-06b4-4852-aff6-22694d65ac4c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "beta = 5.5\n",
    "N = 500\n",
    "T = 10\n",
    "c = 0.5\n",
    "# Selects which trained-weight files are loaded (used in the weight paths of later cells)\n",
    "seed = 1\n",
    "# The growth (beta = beta) and no-growth (beta = 0) simulations share N, T and c; build both\n",
    "# paths from the parameters above so they cannot silently drift out of sync.\n",
    "# weights_only=False unpickles arbitrary objects - only use on trusted, locally generated files.\n",
    "sim_prefix = f\"../HSC/sim_HSC_N_{N}_T_{T}_c_{c}\"\n",
    "data = torch.load(f\"{sim_prefix}_beta_{beta}.pkl\", weights_only = False)\n",
    "data_nogrowth = torch.load(f\"{sim_prefix}_beta_0.pkl\", weights_only = False)\n",
    "dim = data['x'].shape[1]\n",
    "# One tensor of cells per snapshot, in increasing time-index order (np.unique returns sorted values)\n",
    "X = [torch.tensor(data['x'][data['t_idx'] == i, :], device = device, dtype = torch.float32) for i in np.unique(data['t_idx'])]\n",
    "ts = torch.linspace(0, data['t_final'], len(X), device = device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "833c77dc-1a61-4c8c-bacb-b8ccb99cb8d1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fit the PCA basis on the no-growth simulation only and project the growth data into it;\n",
    "# later cells reuse `pca_op` so every embedding shares the same coordinates.\n",
    "pca_op = sk.decomposition.PCA().fit(data_nogrowth['x'])\n",
    "X_pca = pca_op.transform(data['x'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75d2de5a-b55b-47c6-82c9-268a2000516d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# First two principal components of the simulated cells, colored by snapshot index\n",
    "plt.scatter(X_pca[:, 0], X_pca[:, 1], c = data['t_idx'])\n",
    "plt.xlabel(\"PC1\"); plt.ylabel(\"PC2\"); plt.title(\"Simulated data by time index\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "666848ad-8dae-419b-9ef3-67f8960581d4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "# Make the project's src/ importable; guarded so re-running the cell does not keep growing sys.path\n",
    "SRC_DIR = \"../../src\"\n",
    "if SRC_DIR not in sys.path:\n",
    "    sys.path.append(SRC_DIR)\n",
    "import models, utils, train\n",
    "from torch import optim\n",
    "# Noise levels exp(0), ..., exp(-2) (5 values), passed as `sigmas` to the Langevin samplers\n",
    "sigmas = torch.linspace(0, -2, 5).exp().to(device)\n",
    "s = models.NCScoreFunc(d = dim, hidden_sizes = [128, 128, 128], activation = torch.nn.ReLU, time_dependent = True).to(device)\n",
    "# map_location lets weights saved on a GPU load on a CPU-only machine (and vice versa)\n",
    "s.load_state_dict(torch.load(f\"../HSC/weights/params_NCScoreFunc_default_c_{c}_seed_{seed}_final.pt\", map_location = device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7b155b2-ad6e-47e7-9b7f-12e0ed6b535d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One Langevin sampler per observed time in `ts`, each given 1000 standard-normal points\n",
    "# (presumably its initial state). The default argument `_t=...` binds each time eagerly; a plain\n",
    "# closure over `_s` would late-bind to the last time. `s` inside the lambda is the score network.\n",
    "samplers = [models.LangevinSampler(lambda x, sigma, _t = torch.scalar_tensor(_s).to(device): s(_t, x, sigma), \n",
    "                      torch.randn(1_000, dim).to(device), \n",
    "                      sigmas = sigmas, dt = 1e-3, n_iter = 1000) for _s in ts]\n",
    "# Stack per-time samples along a new leading time axis (assumes sample() returns (n, dim)).\n",
    "# The loop variable is not called `s`, so it cannot be confused with the score network above.\n",
    "x_sample = torch.stack([sampler.sample().cpu() for sampler in samplers])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f7f50de-40c0-4373-afbb-900c6a3cf39d",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize = (10, 4))\n",
    "\n",
    "# Left: simulated data in PCA coordinates, colored by snapshot index\n",
    "ax = fig.add_subplot(1, 2, 1, projection='3d')\n",
    "ax.view_init(30, -120)\n",
    "scatter = ax.scatter(X_pca[:, 0], X_pca[:, 1], X_pca[:, 2], c=data['t_idx'], cmap='viridis', alpha = 0.25, edgecolor = 'k')\n",
    "\n",
    "# Right: Langevin samples from the learned score, projected with the same PCA\n",
    "ax = fig.add_subplot(1, 2, 2, projection='3d')\n",
    "ax.view_init(30, -120)\n",
    "n_times, n_per_time = x_sample.shape[0], x_sample.shape[1]\n",
    "x_sample_pca = pca_op.transform(x_sample.reshape(-1, dim))\n",
    "# Time index of every flattened sample: 0,...,0, 1,...,1, ...\n",
    "_ts = torch.arange(n_times).repeat_interleave(n_per_time)\n",
    "ax.scatter(x_sample_pca[:, 0], x_sample_pca[:, 1], x_sample_pca[:, 2], alpha = 0.3, s = 1, c = _ts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b8e9249-c4fd-4f78-868a-20102e8ee414",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "importlib.reload(models)\n",
    "importlib.reload(train)\n",
    "# Squared noise scale: the SDE simulations use sigma = D**0.5 = 0.5\n",
    "D = 0.5**2\n",
    "\n",
    "# NOTE: num_iter is not used anywhere in this notebook\n",
    "num_iter = 5000\n",
    "\n",
    "v_pfi = models.NGMVectorField(dim, hidden_sizes = [64, 64], GL_reg = 0.03).to(device)\n",
    "# map_location lets weights saved on a GPU load on a CPU-only machine (and vice versa)\n",
    "v_pfi.load_state_dict(torch.load(f\"../HSC/weights/params_PFI_NGMVectorField_default_seed_{seed}_final.pt\", map_location = device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb00a05a-58a6-40ff-96f8-7f1b265893ec",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Snapshot sizes relative to the first snapshot (not used further in this notebook)\n",
    "m_ratios = torch.tensor([x.shape[0] / X[0].shape[0] for x in X]).float()\n",
    "v_upfi = models.ODEFlowGrowth(dim, v_mod = models.NGMVectorField, kwargs_v = {'hidden_sizes' : [64, 64], 'GL_reg' : 0.03}, kwargs_g = {'hidden_sizes' : [64, ]}).to(device)\n",
    "# map_location lets weights saved on a GPU load on a CPU-only machine (and vice versa)\n",
    "v_upfi.load_state_dict(torch.load(f\"../HSC/weights/params_UPFI_ODEFlowGrowth_NGM_default_seed_{seed}_final.pt\", map_location = device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae764537-57cb-4873-9ce5-5107e4a16a89",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_boolODE_reference_network(path, genes):\n",
    "    \"\"\"Load a BoolODE reference network CSV as a signed gene-by-gene matrix.\n",
    "\n",
    "    Each CSV row (col0, col1, sign) sets entry [col1, col0] to +1 for '+' and -1 for '-'\n",
    "    (so rows presumably index targets and columns regulators - confirm against BoolODE).\n",
    "    All other entries stay 0; an unknown sign raises KeyError.\n",
    "\n",
    "    Args:\n",
    "        path: path to the reference network CSV (its first three columns are used).\n",
    "        genes: gene labels used as both the row and the column index.\n",
    "    Returns:\n",
    "        pd.DataFrame indexed by `genes` on both axes.\n",
    "    \"\"\"\n",
    "    edges = pd.read_csv(path)\n",
    "    sign_of = {\"+\" : 1, \"-\" : -1}\n",
    "    A_ref = pd.DataFrame(np.zeros((len(genes), len(genes)), int), index = genes, columns = genes)\n",
    "    for regulator, target, sign in zip(edges.iloc[:, 0], edges.iloc[:, 1], edges.iloc[:, 2]):\n",
    "        A_ref.loc[target, regulator] = sign_of[sign]\n",
    "    return A_ref\n",
    "A_ref = load_boolODE_reference_network(f\"../../data/HSC/refNetwork.csv\", genes)\n",
    "# Reorder rows and columns into the display order\n",
    "A_ref = A_ref.loc[genes_reord, :].loc[:, genes_reord]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34dd4403-93fd-4d8f-9f7f-7d555b636d89",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Signed reference network; under RdBu_r, +1 entries are red and -1 entries are blue\n",
    "plt.imshow(A_ref, vmin = -1, vmax = 1, cmap = \"RdBu_r\")\n",
    "plt.gca().invert_yaxis()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b19d4a2b-4c53-43ea-845f-d7f267ba6704",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def maskdiag(A):\n",
    "    \"\"\"Return A with its diagonal zeroed (self-loops ignored).\n",
    "\n",
    "    A is a square array-like supporting elementwise `*` with a numpy array\n",
    "    (e.g. ndarray or DataFrame); the result has the same type as `A * ndarray`.\n",
    "    \"\"\"\n",
    "    return A * (1 - np.eye(A.shape[0]))\n",
    "_x = torch.tensor(data['x'], dtype = torch.float32)\n",
    "# Weight-based causal graphs read from the NGM layers (transposed, presumably to match A_ref's layout)\n",
    "A_pfi = v_pfi.net.net.causal_graph(w_threshold=0).T\n",
    "A_upfi = v_upfi.v_net.net.net.causal_graph(w_threshold=0).T\n",
    "# Jacobian-based graphs: dv/dx averaged over all cells. Pass an explicit time rather than IPython's\n",
    "# `_` (the previous cell's output), which is hidden state and undefined outside IPython. The NGM\n",
    "# fields appear to ignore t - confirm; for a time-dependent field this evaluates at ts[0].\n",
    "t_jac = ts[0]\n",
    "A_pfi_jac = torch.vmap(torch.func.jacrev(lambda x: v_pfi(t_jac, x)))(_x.to(device)).mean(0).detach().cpu().T\n",
    "A_upfi_jac = torch.vmap(torch.func.jacrev(lambda x: v_upfi.v_net(t_jac, x)))(_x.to(device)).mean(0).detach().cpu().T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b1d681c-9599-4ba6-91cc-3b3beb512b93",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib.patches as patches\n",
    "\n",
    "# (row, col) grid positions of the reference network's off-diagonal edges; they are outlined\n",
    "# in red on each learned graph so hits and misses are visible at a glance.\n",
    "true_edges = list(zip(*np.where(maskdiag(np.abs(A_ref)))))\n",
    "\n",
    "plt.figure(figsize = (7.5, 2.5))\n",
    "plt.subplot(1, 3, 1)\n",
    "plt.imshow(maskdiag(np.abs(A_ref)), cmap = \"bone_r\", vmin = 0, vmax = 1); plt.gca().invert_yaxis()\n",
    "plt.title(\"True\")\n",
    "plt.axis('off')\n",
    "for panel, (A_hat, name) in enumerate([(A_pfi, \"PFI+NGM\"), (A_upfi, \"UPFI+NGM\")], start = 2):\n",
    "    plt.subplot(1, 3, panel)\n",
    "    ax = plt.gca()\n",
    "    im = plt.imshow(maskdiag(A_hat), cmap = \"bone_r\", vmin = 0, vmax = 0.5)\n",
    "    ax.invert_yaxis()\n",
    "    for row, col in true_edges:\n",
    "        ax.add_patch(patches.Rectangle((col-0.5, row-0.5), 1, 1, linewidth=2,\n",
    "                                       edgecolor='red', facecolor='none'))\n",
    "    plt.title(name)\n",
    "    plt.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/HSC_NGM_average_causal_graphs.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e452f8e8-0fe6-4c40-8416-e75a23823899",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# (matrix, title, symmetric color limit) for each panel; diagonals are masked in every panel\n",
    "jac_panels = [(A_ref, \"True\", 1), (A_pfi_jac, \"PFI+NGM\", 5), (A_upfi_jac, \"UPFI+NGM\", 5)]\n",
    "plt.figure(figsize = (7.5, 2.5))\n",
    "for panel, (mat, name, vlim) in enumerate(jac_panels, start = 1):\n",
    "    plt.subplot(1, 3, panel)\n",
    "    plt.imshow(maskdiag(mat), cmap = \"RdBu_r\", vmin = -vlim, vmax = vlim)\n",
    "    plt.gca().invert_yaxis()\n",
    "    plt.title(name)\n",
    "    plt.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/HSC_NGM_average_jacobians.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bd8911f-5b8c-42f4-a97f-0ce0e4784039",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "x0 = X[0]\n",
    "# Simulate both learned drifts as SDEs with noise scale sigma = sqrt(D), started from the first\n",
    "# snapshot and reported at every time in `ts`. UPFI uses only its velocity net (v_upfi.v_net).\n",
    "sigma_sde = D**0.5\n",
    "sde_pfi = models.SDE(lambda t, x: v_pfi(t, x), sigma = sigma_sde)\n",
    "sde_upfi = models.SDE(lambda t, x: v_upfi.v_net(t, x), sigma = sigma_sde)\n",
    "with torch.no_grad():\n",
    "    xs_t_pfi, xs_t_upfi = (torchsde.sdeint(sde, x0, ts, method = \"euler\").cpu() for sde in (sde_pfi, sde_upfi))\n",
    "# Flatten (time, cell, dim) -> (time * cell, dim) and project into the shared PCA basis\n",
    "y_pred_pfi, y_pred_upfi = (pca_op.transform(xs.view(-1, dim)) for xs in (xs_t_pfi, xs_t_upfi))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f08f2355-6a47-45b8-a43c-6d32b43b9576",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize = (7.5, 2.5))\n",
    "\n",
    "def _set_pca_limits(ax):\n",
    "    \"\"\"Apply the common PCA-space view box shared by all three panels.\"\"\"\n",
    "    ax.set_xlim(-5, 8); ax.set_ylim(-5, 5); ax.set_zlim(-5, 5)\n",
    "\n",
    "# Panel 1: observed snapshots, colored by time index\n",
    "ax = fig.add_subplot(1, 3, 1, projection='3d'); ax.view_init(30, -120)\n",
    "scatter = ax.scatter(X_pca[:, 0], X_pca[:, 1], X_pca[:, 2], c=data['t_idx'], cmap='viridis', alpha = 0.25, edgecolor = 'k')\n",
    "_set_pca_limits(ax)\n",
    "plt.title(\"Data\")\n",
    "\n",
    "# Panels 2-3: SDE simulations from x0; each consecutive block of x0.shape[0] points shares one time\n",
    "t_colors = ts.repeat_interleave(x0.shape[0]).cpu()\n",
    "for panel, (Y, name) in enumerate([(y_pred_pfi, \"PFI+NGM\"), (y_pred_upfi, \"UPFI+NGM\")], start = 2):\n",
    "    ax = fig.add_subplot(1, 3, panel, projection='3d'); ax.view_init(30, -120)\n",
    "    ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], alpha = 0.3, s = 25, c = t_colors)\n",
    "    _set_pca_limits(ax)\n",
    "    plt.title(name)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/HSC_NGM_reconstruction_3dplot.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "111dbdb4-575e-46f3-983c-3b20fb355d70",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Edge-recovery precision-recall: labels are |A_ref| with the diagonal zeroed, scores are the\n",
    "# |weight-based causal graph| with the diagonal zeroed.\n",
    "# NOTE(review): the masked diagonal (label 0, score 0) is still part of the evaluation set;\n",
    "# consider excluding it.\n",
    "y = np.abs(maskdiag(A_ref).values).flatten()\n",
    "plt.figure(figsize = (2.5, 2.5))\n",
    "for A_hat, name in [(A_upfi, \"UPFI+NGM\"), (A_pfi, \"PFI+NGM\")]:\n",
    "    yhat = maskdiag(np.abs(A_hat)).flatten()\n",
    "    prec, rec, thresh = sk.metrics.precision_recall_curve(y, yhat)\n",
    "    avg_prec = sk.metrics.average_precision_score(y, yhat)\n",
    "    plt.plot(rec, prec, label=f'{name}, AUPR = {avg_prec:.2f}')\n",
    "    plt.fill_between(rec, prec, alpha=0.3)\n",
    "plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.26),\n",
    "          fancybox=True, shadow=True, ncol=5)\n",
    "plt.xlabel(\"Recall\"); plt.ylabel(\"Precision\")\n",
    "# The legend is anchored below the axes and falls outside the small 2.5in canvas;\n",
    "# bbox_inches='tight' grows the saved bounding box so the legend is not clipped.\n",
    "plt.savefig(\"../../figures/HSC_NGM_AUPRC.svg\", bbox_inches = \"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7729379-f513-4ce2-98a8-860c2de02d3f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# ROC curves for recovering reference edges from |weight-based causal graph| scores, one panel per method\n",
    "y = np.abs(maskdiag(A_ref).values).flatten()\n",
    "plt.figure(figsize = (5, 2.5))\n",
    "for panel, (A_hat, name) in enumerate([(A_pfi, \"PFI+NGM\"), (A_upfi, \"UPFI+NGM\")], start = 1):\n",
    "    plt.subplot(1, 2, panel)\n",
    "    yhat = maskdiag(np.abs(A_hat)).flatten()\n",
    "    fpr, tpr, thresh = sk.metrics.roc_curve(y, yhat)\n",
    "    auc = sk.metrics.roc_auc_score(y, yhat)\n",
    "    plt.plot(fpr, tpr, color='blue', label=f'AUROC = {auc:.2f}')\n",
    "    plt.fill_between(fpr, tpr, alpha=0.3, color='blue')\n",
    "    plt.title(name)\n",
    "    plt.legend(loc = 'lower right')\n",
    "    plt.xlabel(\"FPR\"); plt.ylabel(\"TPR\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/HSC_NGM_AUROC.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6819c481-97e6-4ab1-ab00-0c929e470571",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "# Saved evaluation results, one file per run. Sort for a deterministic row order (glob order is\n",
    "# filesystem-dependent) and fail loudly if nothing matched.\n",
    "files = sorted(glob.glob(\"../HSC/evals/auprc_results_ngm_*\"))\n",
    "assert files, \"no files matched ../HSC/evals/auprc_results_ngm_*\"\n",
    "# The results are nested dicts that may hold non-tensor objects (e.g. numpy scalars), which the\n",
    "# weights_only=True default (torch >= 2.6) rejects. weights_only=False runs arbitrary pickle\n",
    "# code - only use on trusted, locally generated files.\n",
    "res = [torch.load(f, map_location = \"cpu\", weights_only = False) for f in files]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc29c5a5-b0bd-4a20-a9f1-7d1851fd278d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sb\n",
    "df1 = pd.DataFrame([{x : r[x]['avg_prec'] for x in ['pfi', 'upfi', 'pfi_jac', 'upfi_jac']} for r in res])\n",
    "# Per-method mean and sample std (ddof=1) of the average precision across the loaded runs\n",
    "_df1_mean = df1.mean()\n",
    "_df1_std = df1.std()\n",
    "# One-row table of LaTeX 'mean +/- std' strings built with f-strings (avoids the deprecated\n",
    "# DataFrame.applymap and also handles numpy float types such as np.float32)\n",
    "_df1_str = pd.DataFrame([{col : f\"{_df1_mean[col]:.2f} $\\\\pm$ {_df1_std[col]:.2f}\" for col in df1.columns}])\n",
    "# Bold the method with the highest mean\n",
    "_best1 = _df1_mean.idxmax()\n",
    "_df1_str.loc[0, _best1] = \"\\\\textbf{\" + _df1_str.loc[0, _best1] + \"}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29f67a8f-8a61-46c4-98cc-e20098b5119f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "df2 = pd.DataFrame([{x : r[x]['auroc'] for x in ['pfi', 'upfi', 'pfi_jac', 'upfi_jac']} for r in res])\n",
    "# Per-method mean and sample std (ddof=1) of the AUROC across the loaded runs\n",
    "_df2_mean = df2.mean()\n",
    "_df2_std = df2.std()\n",
    "# One-row table of LaTeX 'mean +/- std' strings built with f-strings (avoids the deprecated\n",
    "# DataFrame.applymap and also handles numpy float types such as np.float32)\n",
    "_df2_str = pd.DataFrame([{col : f\"{_df2_mean[col]:.2f} $\\\\pm$ {_df2_std[col]:.2f}\" for col in df2.columns}])\n",
    "# Bold the method with the highest mean\n",
    "_best2 = _df2_mean.idxmax()\n",
    "_df2_str.loc[0, _best2] = \"\\\\textbf{\" + _df2_str.loc[0, _best2] + \"}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9070d804-4347-4cb1-a822-4bb60409c613",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Stack the PR and ROC summaries under an outer index level and fix the column order\n",
    "# used in the LaTeX table\n",
    "column_order = [\"upfi\", \"upfi_jac\", \"pfi\", \"pfi_jac\"]\n",
    "_df_str = pd.concat([_df1_str, _df2_str], keys = [\"PR\", \"ROC\"])[column_order]\n",
    "_df_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01d040b6-c77b-44d0-9e24-1b2cc6c481e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# escape=False keeps the \\textbf{...} and $\\pm$ markup intact; whether to_latex escapes by\n",
    "# default depends on the pandas version/config, so make it explicit.\n",
    "print(_df_str.to_latex(escape = False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "978a2dd8-a422-4a8c-a5cb-c5fe5d02b154",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "upfi",
   "language": "python",
   "name": "upfi"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
