{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76226c81-fe7d-4339-8577-34ed9519345e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import sys\n",
    "sys.path.append(\"../../src\")\n",
    "import models\n",
    "from torch import optim\n",
    "import torchdiffeq\n",
    "from torchdiffeq import odeint\n",
    "import geomloss\n",
    "from tqdm import tqdm\n",
    "device = torch.device('cuda:0')\n",
    "\n",
    "import sklearn as sk\n",
    "import sklearn.decomposition\n",
    "import sklearn.preprocessing\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import importlib\n",
    "importlib.reload(models)\n",
    "\n",
    "import dynamo as dyn\n",
    "import scanpy as sc\n",
    "import anndata as ad\n",
    "import pandas as pd\n",
    "\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "adata = dyn.read_h5ad('../../data/larry/invitro_after_cell_velocities.h5ad')\n",
    "\n",
    "df_meta = pd.read_csv(\"data/GSM4185642_stateFate_inVitro_metadata.txt\", sep = \"\\t\")\n",
    "df_meta = df_meta.set_index(df_meta.iloc[:, 1].str.split(\"-\").str[0].str.cat(df_meta.iloc[:, 1].str.split(\"-\").str[1]).str.cat(df_meta.iloc[:, 0], sep = \"-\"))\n",
    "df_meta.loc[:, \"neu_mon\"] = 0\n",
    "df_meta.loc[:, \"neu_mon\"].iloc[pd.read_csv(\"data/stateFate_inVitro_neutrophil_monocyte_trajectory.txt\").values[:,0]]=1\n",
    "adata.obs.loc[:, \"neu_mon\"] = df_meta.loc[adata.obs.index].neu_mon\n",
    "sc.pl.scatter(adata, basis = \"spring\", color = \"neu_mon\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1d18a7c-684f-442a-ae8c-01b0de3f65d6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "from scipy.io import mmread\n",
    "X_clone = mmread(\"data/GSM4185642_stateFate_inVitro_clone_matrix.mtx\").tocsr()\n",
    "clone_ids = np.argmax(X_clone, 1)\n",
    "clone_ids[X_clone.sum(-1) == 0] = -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "396209ab-7e25-457d-9b60-6c2a4cd12a2e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_cloneid = pd.DataFrame({'cloneid' : np.asarray(clone_ids).flatten()}, index = df_meta.index, )\n",
    "adata.obs.loc[:, \"cloneid\"] = df_cloneid.loc[adata.obs.index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea045106-24f4-4680-ac56-918272de8aa1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "adata_subset = adata[adata.obs.neu_mon == 1, :].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "873fdd04-dcf3-432f-a7e5-6cbefbd947cd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "dyn.pl.streamline_plot(adata_subset, color='Cell type annotation', basis='spring')\n",
    "dyn.pl.streamline_plot(adata_subset, color='Cell type annotation', basis='umap')\n",
    "dyn.pl.streamline_plot(adata_subset, color='Cell type annotation', basis='pca')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cd6fc34-cb0c-47f9-b0aa-e44b72a76ac8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "sc.pl.scatter(adata_subset, basis = \"spring\", color = [\"neu_mon\", \"Time point\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6449cc5-17f5-4bc8-aac7-4c2f627daa47",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "t_idx = np.array([{2 : 0, 4 : 1, 6 : 2}[x] for x in adata_subset.obs['Time point'].values.astype(int)])\n",
    "adata_subset = adata_subset[np.argsort(t_idx), :]\n",
    "t_idx = np.sort(t_idx)\n",
    "ts = torch.linspace(0, 1, 3)\n",
    "dim = 10\n",
    "data = {'x' : torch.vstack([torch.tensor(adata_subset.obsm[\"X_pca\"][t_idx == i, :][:, :dim], dtype = torch.float32) for i in np.sort(np.unique(t_idx))]),\n",
    "        'v' : torch.vstack([torch.tensor(adata_subset.obsm[\"velocity_pca\"][t_idx == i, :][:, :dim], dtype = torch.float32) for i in np.sort(np.unique(t_idx))]),\n",
    "        'x_spring' : adata_subset.obsm[\"X_spring\"],\n",
    "        'celltype' : adata_subset.obs.loc[:, \"Cell type annotation\"],\n",
    "        't_idx' : t_idx, 't_final' : 1.0, 'id' : list(adata_subset.obs.index), \n",
    "        'cloneid' : adata_subset.obs.cloneid}\n",
    "torch.save(data, \"data_pca.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23309cc5-a595-4494-9d64-dad3be59151b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "## Sanity check\n",
    "#_data = torch.load(\"data_pca.pkl\")\n",
    "#_df = pd.DataFrame(_data['x'], index = _data['id'])\n",
    "#_df.loc[adata_subset.obs.index, :].values - adata_subset.obsm[\"X_pca\"][:, range(10)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b937a688-5373-4ed2-a7d5-0ed2c481e175",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# _df = pd.DataFrame(_data['v'], index = _data['id'])\n",
    "# np.linalg.norm(_df.loc[adata_subset.obs.index, :].values - adata_subset.obsm[\"velocity_pca\"][:, range(10)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ea9544e-ca4a-4e9a-9764-2f4590de0571",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.tl.score_genes_cell_cycle(adata_subset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36ad84f0-6870-4252-a705-af37aa7bcb24",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "upfi",
   "language": "python",
   "name": "upfi"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
