{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f89e77d-5ce1-4b4b-93db-e561c90a64e4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import importlib\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy as sp\n",
    "import seaborn as sb\n",
    "import matplotlib.pyplot as plt\n",
    "import sklearn as sk\n",
    "import sklearn.cluster  # used by the KMeans fate-clustering cell\n",
    "import sklearn.decomposition\n",
    "import umap\n",
    "import torch\n",
    "\n",
    "# Project-local modules are imported from relative source directories.\n",
    "sys.path.append(\"../../src/gwot_simulator/\")\n",
    "sys.path.append(\"../../src/sim/\")\n",
    "import sim as simulator\n",
    "import util\n",
    "\n",
    "import CLEBoolODE_HSC\n",
    "# Reload BEFORE binding names out of the module: `from X import f` keeps a reference\n",
    "# to the old function, so reloading afterwards would leave `interact_HSC` stale on re-run.\n",
    "CLEBoolODE_HSC = importlib.reload(CLEBoolODE_HSC)\n",
    "import CLEBoolODE_HSC as hsc_cle\n",
    "from CLEBoolODE_HSC import interact_HSC\n",
    "\n",
    "# ---- Model parameters ----\n",
    "Ndim = 11  # number of genes (see Gn below)\n",
    "mvec = 20*np.ones((Ndim,))\n",
    "kvec = 10*np.ones((Ndim,))\n",
    "nvec = 10*np.ones((Ndim,))\n",
    "lp = 1\n",
    "r = 10\n",
    "lx = 5  # linear decay coefficient in f_func / g_func\n",
    "c = 0.5  # noise scale (alternative value tried: 0.25)\n",
    "betamax = 5.5  # maximum birth rate; 0 disables birth/death\n",
    "\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Initial states: Gamma(5, 0.5) draws, rescaled per gene.\n",
    "ic_func = lambda N, d: np.random.gamma(5,0.5,(N, d)) * np.array([0.3, 1, 0.3, 0.3, 0.3, 0.3, 1.0, 1.0, 0.3, 0.3, 0.3])\n",
    "# Drift: interaction term minus linear decay; noise: c * sqrt(interaction + decay).\n",
    "# Both take cells as rows; the transposes suggest interact_HSC expects genes x cells (TODO confirm).\n",
    "f_func = lambda xold, _: (interact_HSC((r/lp)*xold.T,Ndim,mvec,kvec,nvec) - lx*xold.T).T\n",
    "g_func = lambda xold, _: c * np.sqrt(interact_HSC((r/lp)*xold.T,Ndim,mvec,kvec,nvec) + lx*xold.T).T\n",
    "\n",
    "# Gene name -> column index in the state vector.\n",
    "Gn = {'G1':0,'G2':1,'Fg':2,'E':3,\n",
    "      'Fli':4,'S':5,'Ceb':6,'P':7,\n",
    "      'cJ':8, 'Eg':9,'G':10\n",
    "     }\n",
    "genes = pd.Series({v : k for (k, v) in Gn.items()})\n",
    "\n",
    "# Birth rate: sigmoid in gene E (midpoint 1.5), saturating at betamax. No death.\n",
    "beta = lambda x, t: betamax*((np.tanh(1*(x[Gn['E']] - 1.5)) + 1)/2)\n",
    "delta = lambda x, t: 0\n",
    "\n",
    "T = 10   # number of sampled timepoints\n",
    "N = 500  # sample size per timepoint\n",
    "sim = simulator.Simulation(dV = f_func,\n",
    "                           D = g_func,\n",
    "                           birth_death = betamax > 0,\n",
    "                           birth = beta, death = delta,\n",
    "                           N = np.array([N, ] * T),\n",
    "                           T = T,\n",
    "                           d = Ndim,\n",
    "                           t_final = 1.0,\n",
    "                           ic_func = ic_func,\n",
    "                           pool = None)\n",
    "\n",
    "sim_steps = 1_000\n",
    "sim.sample(steps_scale = int(sim_steps/sim.T), trunc = 10_000);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21a4cbe5-fe21-43c8-9507-9ffe8e1aa816",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Ground-truth trajectories, at the same step resolution used by sim.sample above.\n",
    "paths_gt = sim.sample_trajectory(\n",
    "    N = 1000,\n",
    "    steps_scale = int(sim_steps / sim.T),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bd5306f-75bd-4752-ba39-6138f6b34a34",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Number of sampled cells at each timepoint.\n",
    "cells_per_t = pd.Series(sim.t_idx).value_counts().sort_index()\n",
    "fig, ax = plt.subplots(figsize = (3, 3))\n",
    "sb.barplot(x = cells_per_t.index, y = cells_per_t.values, ax = ax)\n",
    "ax.set(xlabel = \"Timepoint\", ylabel = \"Cells\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9059e504-9eb8-42bc-a768-b386d3fd5828",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Hierarchically clustered expression heatmap, columns labelled by gene name.\n",
    "expr = pd.DataFrame(np.asarray(sim.x), columns = genes.sort_index().values)\n",
    "sb.clustermap(expr, figsize = (5, 5));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ce092c8-5bc3-4f63-bdbd-6e82ced7f92c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Linear (PCA) and nonlinear (UMAP) embeddings of all sampled cells.\n",
    "# random_state makes the UMAP embedding reproducible (UMAP then runs single-threaded).\n",
    "pca_op = sk.decomposition.PCA()\n",
    "umap_op = umap.UMAP(n_neighbors=50, min_dist = 0.7, random_state = 0)\n",
    "X_umap = umap_op.fit_transform(sim.x)\n",
    "X_pca = pca_op.fit_transform(sim.x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c62708a-8cdb-4915-a39d-a85ca68741c1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Cells coloured by sampling timepoint in each embedding.\n",
    "fig, (ax_pca, ax_umap) = plt.subplots(1, 2, figsize = (8, 3.5))\n",
    "for ax, X_emb, name in [(ax_pca, X_pca, \"PCA\"), (ax_umap, X_umap, \"UMAP\")]:\n",
    "    pts = ax.scatter(X_emb[:, 0], X_emb[:, 1], c = sim.t_idx)\n",
    "    ax.set(title = name, xlabel = f\"{name} 1\", ylabel = f\"{name} 2\")\n",
    "fig.colorbar(pts, ax = [ax_pca, ax_umap], label = \"Timepoint\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af59df09-0a51-4418-85a2-50bbaa617789",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Left: norm of the diffusion term g(x); right: birth rate beta(x). Both on PC1/PC2.\n",
    "fig, (ax_g, ax_b) = plt.subplots(1, 2, figsize = (6, 3))\n",
    "g_norm = np.sqrt((g_func(sim.x, None)**2).sum(-1))\n",
    "# Clip the colour scale to the 5-95% quantiles so outliers don't wash it out.\n",
    "pts = ax_g.scatter(X_pca[:, 0], X_pca[:, 1], c = g_norm, cmap = \"viridis\", alpha = 0.25,\n",
    "                   vmin = np.quantile(g_norm, 0.05), vmax = np.quantile(g_norm, 0.95))\n",
    "fig.colorbar(pts, ax = ax_g)\n",
    "ax_g.set(title = \"|g(x)|\", xlabel = \"PC1\", ylabel = \"PC2\")\n",
    "birth_rate = np.array([beta(x, None) for x in sim.x])\n",
    "pts = ax_b.scatter(X_pca[:, 0], X_pca[:, 1], c = birth_rate, cmap = \"viridis\", alpha = 0.25)\n",
    "fig.colorbar(pts, ax = ax_b)\n",
    "ax_b.set(title = \"beta(x)\", xlabel = \"PC1\", ylabel = \"PC2\")\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0add90c6-aa9d-43f4-aab1-f067d0544a2f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Expression of each gene on the UMAP embedding, one panel per gene.\n",
    "n_cols = 4\n",
    "n_rows = -(-len(genes) // n_cols)  # ceiling division\n",
    "fig, axes = plt.subplots(n_rows, n_cols, figsize = (8, 6), squeeze = False)\n",
    "for ax, g in zip(axes.ravel(), genes):\n",
    "    ax.scatter(X_umap[:, 0], X_umap[:, 1], c = sim.x[:, Gn[g]], cmap = \"magma\", alpha = 0.25, s = 1)\n",
    "    ax.set_title(g)\n",
    "for ax in axes.ravel()[len(genes):]:\n",
    "    ax.set_axis_off()  # hide panels left over when genes don't fill the grid\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27cf927e-ccd3-4c0a-b0c0-2e6788868594",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Bundle everything needed downstream. f_func / g_func ignore their second argument,\n",
    "# so pass None explicitly rather than IPython's `_` (the previous cell's output).\n",
    "data = {'x' : sim.x, 't_idx' : sim.t_idx, 't_final' : sim.t_final,\n",
    "        'f': f_func(sim.x, None), 'g': g_func(sim.x, None),\n",
    "        'x_paths' : paths_gt,\n",
    "        'beta' : np.array([beta(x, None) for x in sim.x])}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cb5e7d3-1ee7-4dc6-8023-598282e7d130",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "import sys\n",
    "sys.path.append(\"../../src\")\n",
    "import sklearn.cluster\n",
    "import evals\n",
    "\n",
    "# Terminal fates: cluster the cells observed at the final timepoint.\n",
    "clust = sk.cluster.KMeans(n_clusters = 4, n_init = 10, random_state = 0)\n",
    "clust.fit(data['x'][data['t_idx'] == sim.T-1])\n",
    "_centroids = torch.tensor(clust.cluster_centers_, dtype = torch.float32)\n",
    "\n",
    "def _make_fate_fn(t, n_steps = 250):\n",
    "    \"\"\"Return a function mapping cells observed at timepoint `t` to simulated end states.\n",
    "\n",
    "    The returned function integrates its input (birth/death disabled) with\n",
    "    simulator.sde_integrate for sim.t_final * (1 - t / sim.T) time units in `n_steps`\n",
    "    steps, and returns the snapshot at index n_steps - 1 as a float32 tensor.\n",
    "    Binding `t` here removes the original dependency on the global loop variable.\n",
    "    \"\"\"\n",
    "    t_remaining = sim.t_final*(1 - t/sim.T)\n",
    "    def _get_fate(x):\n",
    "        # Copy so the integrator never aliases the caller's buffer.\n",
    "        x0 = torch.as_tensor(x, dtype = torch.float32).clone()\n",
    "        y = simulator.sde_integrate(sim.dV, sim.D, x0, t_remaining, n_steps,\n",
    "                                    snaps = np.array([n_steps-1, ]), birth_death = False)[0][0]\n",
    "        return torch.as_tensor(y, dtype = torch.float32)\n",
    "    return _get_fate\n",
    "\n",
    "# Fate probabilities per timepoint, scattered back into the row order of data['x']\n",
    "# so that probs[j] refers to cell j (required by the PCA scatter plots below).\n",
    "t_idx = np.asarray(data['t_idx'])\n",
    "prob_blocks, row_blocks = [], []\n",
    "for t in tqdm(range(sim.T)):\n",
    "    rows = np.flatnonzero(t_idx == t)\n",
    "    prob_blocks.append(evals.get_centroid_probs(torch.tensor(data['x'][rows, :]), _make_fate_fn(t), _centroids, n_sample = 25))\n",
    "    row_blocks.append(rows)\n",
    "probs_by_t = torch.vstack(prob_blocks)\n",
    "assert len(probs_by_t) == len(t_idx), \"every cell must have a timepoint in range(sim.T)\"\n",
    "probs = torch.empty_like(probs_by_t)\n",
    "probs[torch.as_tensor(np.concatenate(row_blocks))] = probs_by_t\n",
    "data['probs'] = probs\n",
    "data['centroids'] = _centroids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5226b84b-eb64-4f17-9803-9829ac1b7abb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Per-cluster fate probabilities on three consecutive PCs (red x: cluster centroid).\n",
    "k = 0  # index of the first PC shown; panels use PCs k, k+1, k+2\n",
    "n_fates = clust.n_clusters\n",
    "fig = plt.figure(figsize = (5*n_fates, 5))\n",
    "centroids_pca = pca_op.transform(clust.cluster_centers_)\n",
    "for i in range(n_fates):\n",
    "    ax = fig.add_subplot(1, n_fates, i+1, projection='3d', computed_zorder=False)\n",
    "    ax.view_init(30, -120)\n",
    "    ax.scatter(X_pca[:, k], X_pca[:, k+1], X_pca[:, k+2], c = probs[:, i], alpha = 0.25, s = 2.5)\n",
    "    ax.scatter(centroids_pca[[i, ], k], centroids_pca[[i, ], k+1], centroids_pca[[i, ], k+2], c = 'r', s = 50, marker = 'x')\n",
    "    ax.set_title(f\"Cluster {i} probability\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b14deb96-38c3-42e5-ac02-38c158ce8667",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Shannon entropy of each cell's fate distribution (higher = probability spread over more fates).\n",
    "# xlogy(0, 0) = 0, so zero-probability fates contribute nothing. PC offset `k` comes from the previous cell.\n",
    "fate_entropy = -torch.xlogy(probs, probs).sum(-1)\n",
    "fig = plt.figure(figsize = (5, 5))\n",
    "ax = fig.add_subplot(1, 1, 1, projection='3d', computed_zorder=False)\n",
    "ax.view_init(30, -120)\n",
    "pts = ax.scatter(X_pca[:, k], X_pca[:, k+1], X_pca[:, k+2], c = fate_entropy, alpha = 0.25, s = 5)\n",
    "fig.colorbar(pts, ax = ax, shrink = 0.6, label = \"Fate entropy\")\n",
    "ax.set_title(\"Fate entropy\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73189be2-b647-4b20-919a-7d32b644fdf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# File name encodes the simulation parameters; torch is already imported in the first cell.\n",
    "out_path = f\"sim_HSC_N_{N}_T_{T}_c_{c}_beta_{betamax}.pkl\"\n",
    "torch.save(data, out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "902ec6c7-ecf1-4d68-9bad-feb520c3ed90",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "upfi",
   "language": "python",
   "name": "upfi"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
