{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b738a3-87c2-4389-805e-60f74bbd1a71",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scanpy as sc\n",
    "import scprep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f7ce53c-46ee-457b-9c23-464da733ce83",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa351eed-ac3a-4480-a98e-d7f6ce8a181d",
   "metadata": {},
   "outputs": [],
   "source": [
    "adata = sc.read_h5ad(\"./a/anonymous./time_series/ebdata.h5ad\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11e64244-f46b-49b3-a01f-070394b9d2f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "adata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9405385d-b933-4a40-9a43-64f4e8720089",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.pl.scatter(adata, basis=\"phate\", color=\"leiden\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c350195c-db70-4674-a986-b95e7e030b43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root cell for diffusion pseudotime: the first cell assigned to Leiden cluster \"7\".\n",
    "cluster7_positions = np.flatnonzero(adata.obs[\"leiden\"] == \"7\")\n",
    "adata.uns[\"iroot\"] = cluster7_positions[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97027320-cc5b-4947-97c5-82ee3d62fc95",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.tl.dpt(adata)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fcc91c1-7fb5-42f4-8b68-0d536b3fd511",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.pl.scatter(adata, basis=\"phate\", color=\"sample_labels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df30380f-c5f5-466b-9534-397cba8e88fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.pl.scatter(adata, basis=\"phate\", color=\"dpt_pseudotime\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b080fcd0-dc53-4093-9dfe-e48226339e6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit PHATE on `adata.X` with a single output component.\n",
    "import phate\n",
    "\n",
    "phate_settings = dict(n_components=1, random_state=42)\n",
    "phate_op = phate.PHATE(**phate_settings)\n",
    "oned_embed = phate_op.fit_transform(adata.X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b29f3bfb-5faf-4fd8-8379-bf336c7a8fa6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the embedding as a flat per-cell vector\n",
    "# (assumes `oned_embed` has a single column, shape (n_cells, 1) -- TODO confirm).\n",
    "adata.obs[\"1d-phate\"] = np.ravel(oned_embed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272abeec-c2dc-45f6-bfc5-7befebb5a5de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-column embedding: column 0 is diffusion pseudotime, column 1 is the 1-D PHATE coordinate.\n",
    "pseudotime_and_phate = [adata.obs[\"dpt_pseudotime\"], adata.obs[\"1d-phate\"]]\n",
    "adata.obsm[\"X_phate_time\"] = np.column_stack(pseudotime_and_phate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0b0798f-6556-4bc2-830f-5eca103fe41f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian jitter (sd 0.2) on the integer time-point codes, presumably to reduce overplotting.\n",
    "# A fixed seed keeps the stored embedding identical across kernel restarts.\n",
    "time_jitter = np.random.default_rng(42).normal(scale=0.2, size=adata.n_obs)\n",
    "adata.obsm[\"X_phate_real_time\"] = np.stack(\n",
    "    [\n",
    "        adata.obs[\"sample_labels\"].cat.codes + time_jitter,\n",
    "        adata.obs[\"1d-phate\"],\n",
    "    ],\n",
    "    axis=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9b1b9ba-1903-4319-8f58-a9d5f35e56e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.pl.scatter(adata, basis=\"phate\", color=\"dpt_pseudotime\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61c74ba4-bd5d-48ff-9ac1-eb822d8ace70",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.pl.scatter(adata, basis=\"phate\", color=\"leiden\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bb1d277-514e-470b-aa0c-463c86572eba",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.pl.scatter(adata, basis=\"phate_time\", color=\"sample_labels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8431de87-0c0c-4108-975a-efa0105d4d91",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.pl.scatter(adata, basis=\"phate_real_time\", color=\"leiden\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8be1f38-834d-42e7-b3f3-b6e3f637b1b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.pl.scatter(adata, basis=\"phate_real_time\", color=\"sample_labels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "628a0df4-033f-4925-84a3-6fba46cd27c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "adata.write_h5ad(\n",
    "    \"./a/anonymous./time_series/ebdata_v2.h5ad\", compression=\"gzip\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5686e2fc-4bd2-4668-a2d2-1850d8aad86d",
   "metadata": {},
   "outputs": [],
   "source": [
    "adata = sc.read_h5ad(\"./a/anonymous./time_series/ebdata_v2.h5ad\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c056160a-01d7-4c7a-b793-cbef54cc19ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Z-score the 1-D PHATE coordinate (pandas `std` defaults to the sample standard deviation).\n",
    "oned = adata.obs[\"1d-phate\"]\n",
    "oned = (oned - oned.mean()) / oned.std()\n",
    "adata.obs[\"1d-phate-normalized\"] = oned\n",
    "\n",
    "# Column 0: normalized PHATE coordinate; column 1: jittered time-point code.\n",
    "# The jitter is seeded so the embedding written to disk is reproducible.\n",
    "time_jitter = np.random.default_rng(42).normal(scale=0.2, size=adata.n_obs)\n",
    "adata.obsm[\"X_phate_normalized\"] = np.stack(\n",
    "    [\n",
    "        adata.obs[\"1d-phate-normalized\"],\n",
    "        adata.obs[\"sample_labels\"].cat.codes + time_jitter,\n",
    "    ],\n",
    "    axis=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af9214e2-8ea4-4276-ab22-dcd25f8307b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "adata.write_h5ad(\n",
    "    \"./a/anonymous./time_series/ebdata_v3.h5ad\", compression=\"gzip\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1054bd68-4f8e-4f38-a90a-0b0fb1dd04ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "adata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a730c7cc-8800-4e78-af59-bf9484bb3ad9",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.pl.scatter(adata, basis=\"phate_normalized\", color=\"leiden\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64a951e7-5d23-412c-bb09-e0feeb1aba2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Everything except Leiden cluster \"7\". Take a real copy rather than a view, because\n",
    "# later cells modify `adata_sub` in place (`uns` entries and the clustering).\n",
    "adata_sub = adata[~adata.obs[\"leiden\"].isin([\"7\"])].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0938ee17-42b2-4e49-80ed-de42782350ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.pl.scatter(adata_sub, basis=\"phate_normalized\", color=\"leiden\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02b420a9-f8d6-44d2-b6b1-72093a064e6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "adata_sub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49ead8fc-3a88-4581-8a1d-5a71eb6d279f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the colours stored for the previous clustering (presumably so plots of the new\n",
    "# clusters get a fresh palette). The membership check keeps this cell re-runnable once\n",
    "# the key is already gone.\n",
    "if \"leiden_colors\" in adata_sub.uns:\n",
    "    del adata_sub.uns[\"leiden_colors\"]\n",
    "sc.tl.leiden(adata_sub, resolution=0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a2854e3-f729-480a-ac87-d847b84f65cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.pl.scatter(adata_sub, basis=\"phate_normalized\", color=\"leiden\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58684cc6-ae8d-43e1-b376-eb6c552686e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "adata_sub.write_h5ad(\n",
    "    \"./a/anonymous./time_series/ebdata_no_day_zero_v3.h5ad\",\n",
    "    compression=\"gzip\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8ab98bc-7da5-485b-b6a4-b9b232ab1f03",
   "metadata": {},
   "outputs": [],
   "source": [
    "adata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "783f12ca-2ab4-4f1c-8f19-076838364e86",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d543011-b422-425e-92b5-d93f78fd9d79",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = adata.obs[\"1d-phate-normalized\"]\n",
    "t = adata.obs[\"sample_labels\"].cat.codes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2862ccb3-ce8c-44c5-9ca7-88a46b6a3220",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the cells whose time-point code is 3 or 4.\n",
    "codes_3_4_mask = t.isin([3, 4])\n",
    "X_sub = X[codes_3_4_mask]\n",
    "t_sub = t[codes_3_4_mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79196d7a-0b18-4e1d-afcf-4650b18c95b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalized PHATE coordinate against jittered time-point code.\n",
    "# The jitter is seeded so the figure is identical on every run.\n",
    "y_jitter = np.random.default_rng(42).normal(scale=0.05, size=t_sub.shape)\n",
    "scprep.plot.scatter(X_sub, t_sub + y_jitter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce1bff6d-87f9-488f-b3b4-b7b7eefeeea2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "sns.kdeplot(data=adata.obs, x=\"1d-phate-normalized\", hue=\"sample_labels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f9b775b-6f7d-4ad7-9a06-933181705eee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ot as pot\n",
    "from scipy.spatial import distance_matrix\n",
    "\n",
    "# At most 100 values from each of time codes 3 and 4, shaped as (k, 1) column vectors.\n",
    "x0 = X_sub[t_sub == 3].to_numpy()[:100].reshape(-1, 1)\n",
    "x1 = X_sub[t_sub == 4].to_numpy()[:100].reshape(-1, 1)\n",
    "m, n = len(x0), len(x1)\n",
    "\n",
    "# Uniform weights over each point set.\n",
    "a = np.ones(m) / m\n",
    "b = np.ones(n) / n\n",
    "\n",
    "# Pairwise distances between the two point sets, shape (m, n).\n",
    "M = distance_matrix(x0, x1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "346830bf-fc34-4a7f-9476-ac107d63c2e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "final_gamma = 1e-8\n",
    "# Bind the first return value to `t_final`, not `t`: `t` already holds the time-point\n",
    "# codes that earlier cells index with, and rebinding it breaks those cells on re-run.\n",
    "t_final, t_list, g_list = pot.regpath.regularization_path(\n",
    "    a, b, M, reg=final_gamma, semi_relaxed=True, itmax=100\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0049bc5-7c32-4df6-8230-9554611e205c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transport plan at gamma = 1, reshaped to (m, n) so t2[i, j] couples x0[i] with x1[j].\n",
    "t2 = pot.regpath.compute_transport_plan(1, g_list, t_list).reshape(m, n)\n",
    "if t2.sum() > 0:\n",
    "    t2 = t2 / t2.max()  # scale so the largest entry gets the most opaque line\n",
    "\n",
    "# One segment per positive plan entry, from (x0[i], 3) to (x1[j], 4).\n",
    "# `plt.plot` takes the x values and the y values as separate sequences, not a list of points.\n",
    "for i, j in np.argwhere(t2 > 0):\n",
    "    plt.plot([x0[i, 0], x1[j, 0]], [3, 4], color=\"C2\", alpha=t2[i, j] * 0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f994d9f0-a335-4a14-a826-318912f54f5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "t2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90ae7913-c195-46ac-bb71-21b8cf4cadee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.hist(X_sub[t_sub == 3], bins=100)\n",
    "cax = plt.hist(X_sub[t_sub == 4], bins=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9d9ff6c-34a6-4d30-ad44-59d63c926adc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b97b4b5-4967-4cca-8173-346274f3cfd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scprep\n",
    "\n",
    "# Rows of the normalized embedding for cells with time-point code 3 or 4.\n",
    "codes_3_4 = adata.obs[\"sample_labels\"].cat.codes.isin([3, 4]).to_numpy()\n",
    "scprep.plot.scatter2d(adata.obsm[\"X_phate_normalized\"][codes_3_4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c061706-172f-45ef-9276-7e8d66985221",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}