{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63007c5f-e0bc-4fc1-9cdb-4398eb1f3d57",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.append(\"../../src/gwot_simulator/\")\n",
    "sys.path.append(\"../../src/sim/\")\n",
    "import sim as simulator\n",
    "import util\n",
    "import importlib\n",
    "importlib.reload(util)\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import CLEBoolODE_BF\n",
    "import CLEBoolODE_BF as bf_cle\n",
    "import torch\n",
    "import sklearn as sk\n",
    "importlib.reload(CLEBoolODE_BF)\n",
    "\n",
    "Ndim = 7;\n",
    "mvec = 20*np.ones((Ndim,)); \n",
    "kvec = 10*np.ones((Ndim,)); \n",
    "nvec = 10*np.ones((Ndim,));\n",
    "lp = 1;  r = 10;  lx = 5; \n",
    "c = 0.5;\n",
    "\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "ic_func = lambda N, d: np.random.normal(5,1,(N, d)) * np.array([1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])\n",
    "f_func = lambda xold, _: (CLEBoolODE_BF.interact_BF((r/lp)*xold.T,Ndim,mvec,kvec,nvec) - lx*xold.T).T\n",
    "g_func = lambda xold, _: c * np.sqrt(CLEBoolODE_BF.interact_BF((r/lp)*xold.T,Ndim,mvec,kvec,nvec) + lx*xold.T).T\n",
    "\n",
    "betamax = 5.0\n",
    "# betamax = 0\n",
    "beta = lambda x, t: betamax*((np.tanh(5*(x[Gn['g7']] - 0.7)) + 1)/2)*(1-(np.tanh(5*(x[Gn['g1']] - 0.7)) + 1)/2)\n",
    "delta = lambda x, t: 0\n",
    "Gn = {'g1':0,'g2':1,'g3':2,'g4':3,\n",
    "      'g6':4,'g7':5,'g8':6\n",
    "     }\n",
    "genes = pd.Series({v : k for (k, v) in Gn.items()})\n",
    "\n",
    "# T = 25\n",
    "T = 10\n",
    "N = 500\n",
    "sim = simulator.Simulation(dV = f_func,\n",
    "                     D = g_func, \n",
    "                     birth_death = (betamax > 0), \n",
    "                     birth = beta, death = delta, \n",
    "                     N = np.array([N, ] * T),\n",
    "                     T = T, \n",
    "                     d = Ndim, \n",
    "                     t_final = 1.25, \n",
    "                     ic_func = ic_func, \n",
    "                     pool = None)\n",
    "\n",
    "sim_steps = 1_000\n",
    "sim.sample(steps_scale = int(sim_steps/sim.T), trunc = 5_000);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4bcb434-6d30-472b-97f2-b10f85ace231",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "paths_gt = sim.sample_trajectory(steps_scale = int(sim_steps/sim.T), N = 1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c67f9a0-77cd-450d-9648-22762b492a2a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import seaborn as sb\n",
    "sb.heatmap(sim.x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f125d711-f282-4e3f-be4e-131fc4ce9cec",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "[(sim.t_idx == i).sum() for i in range(T)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da8582f9-f512-49c9-b2d6-ba33ddf33256",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt \n",
    "import sklearn as sk\n",
    "import sklearn.decomposition\n",
    "beta = lambda x, t: ((np.tanh(5*(x[Gn['g7']] - 0.7)) + 1)/2)*(1-(np.tanh(5*(x[Gn['g1']] - 0.7)) + 1)/2)\n",
    "z = 0*beta(sim.x.T, _)\n",
    "plt.scatter(sim.x[:, Gn['g4']], sim.x[:, Gn['g6']], c = z, alpha = 0.1, cmap = \"RdBu_r\", vmax = 2.5, vmin = 0, s = 2.5)\n",
    "plt.colorbar()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6e70f9e-1cdb-43ec-bade-619e818c01f2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "pca_op=sk.decomposition.PCA()\n",
    "X_pca = pca_op.fit_transform(sim.x)\n",
    "X_paths_pca = pca_op.transform(paths_gt.reshape(-1, paths_gt.shape[-1])).reshape(paths_gt.shape)\n",
    "\n",
    "fig=plt.figure(figsize = (5, 5))\n",
    "ax = fig.add_subplot(1, 1, 1, projection='3d'); ax.view_init(30, -120)\n",
    "k=0\n",
    "ax.scatter(X_pca[:, k], X_pca[:, k+1], X_pca[:, k+2], c=z, cmap='viridis', alpha = 0.1, edgecolor = 'k')\n",
    "for i in range(100):\n",
    "    ax.plot(X_paths_pca[i, :, k], X_paths_pca[i, :, k+1], X_paths_pca[i, :, k+2], c='r', alpha = 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f93de77-c047-4670-8c0d-45e2b473fff5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "data = {'x' : sim.x, 'x_raw' : sim.x, 't_idx' : sim.t_idx, 't_final' : sim.t_final, \n",
    "            'x_paths' : paths_gt.reshape(-1, paths_gt.shape[-1]).reshape(paths_gt.shape),\n",
    "            'f': f_func(sim.x, _), 'g': g_func(sim.x, _),\n",
    "            'beta' : np.array([beta(x, None) for x in sim.x])}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a04cee95-4100-4efb-b459-3c44388901b3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Assign fate clusters\n",
    "clust = sk.cluster.KMeans(n_clusters = 2)\n",
    "clust.fit(data['x'][data['t_idx'] == T-1])\n",
    "plt.figure(figsize = (3, 3))\n",
    "plt.scatter(data['x'][:, 4], data['x'][:, 5], alpha = 0.25, s = 30)\n",
    "plt.scatter(clust.cluster_centers_[:, 4], clust.cluster_centers_[:, 5], marker = \"x\", color = \"r\", s = 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "382618f6-2c4a-4f4c-88b6-66d88447f0e0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../src/\")\n",
    "import evals\n",
    "importlib.reload(evals)\n",
    "from tqdm import tqdm\n",
    "ts = np.linspace(0, sim.t_final, sim.T)\n",
    "_centroids = torch.tensor(clust.cluster_centers_, dtype = torch.float32)\n",
    "def _get_fate(x):\n",
    "    x0 = torch.tensor(x, dtype = torch.float32)\n",
    "    y = simulator.sde_integrate(sim.dV, sim.D, x0, sim.t_final*(1-i/sim.T), 100, snaps = np.array([99, ]), birth_death = False)[0][0]\n",
    "    return torch.tensor(y, dtype = torch.float32)\n",
    "\n",
    "probs = []\n",
    "for i in tqdm(range(T)):\n",
    "    probs.append(evals.get_centroid_probs(torch.tensor(data['x'][data['t_idx'] == i, :]), _get_fate, _centroids, n_sample = 25))\n",
    "probs = torch.vstack(probs)\n",
    "data['probs'] = probs\n",
    "data['centroids'] = _centroids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f629ac2c-1401-4909-969e-49e73b51d1d8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.scatter(sim.x[:, 4], sim.x[:, 5], c = probs[:, 0], cmap = \"viridis\", alpha = 0.1, s = 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3441d9dc-220e-4b1f-ac9c-86585c6b2cd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(data, f\"sim_BF_beta_{betamax}_N_{N}_T_{T}_c_{c}.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1b3d5ba-ddba-4694-840e-9e54c261f209",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "upfi",
   "language": "python",
   "name": "upfi"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
