{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%env CUDA_VISIBLE_DEVICES=1\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.lines as mlines\n",
    "from utils.utils import get_path\n",
    "from utils.fig_utils import dataset_to_print\n",
    "from vis_utils.plot import plot_scatter\n",
    "from vis_utils.utils import load_dict\n",
    "from vis_utils.loaders import load_dataset, load_small_cc_dataset\n",
    "import os\n",
    "import numpy as np"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "67c27c4d9699541c"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# switch matplotlib to the style sheet named by style_file\n",
    "style_file = \"utils.style\"\n",
    "plt.style.use(style_file)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "2dd22228392448f4"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# root directory (resolved by get_path) under which datasets are read and figures are written\n",
    "root_path = get_path(\"data\")\n",
    "fig_path = os.path.join(root_path, \"figures\")\n",
    "# savefig does not create missing directories, so make sure the output folder exists\n",
    "os.makedirs(fig_path, exist_ok=True)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "c6bcdccb68e8461d"
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Fig 2D embeddings of single cell datasets"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "fae94176e32bf90c"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Gather the embedding (\"embd\"), labels (\"y\") and metadata dict (\"d\") of every single cell dataset.\n",
    "# Entries are inserted in a fixed order: mca_ss2, pallium, then the four cell cycle datasets.\n",
    "data = {}\n",
    "\n",
    "# malaria\n",
    "dataset = \"mca_ss2\"\n",
    "_, y, _, _, d = load_dataset(root_path, dataset, 15)\n",
    "\n",
    "umap_file_name = f\"{dataset}_umap_correlation_k_10_min_dist_1.0_spread_2_seed_2_init_spectral.pkl\"\n",
    "umap_path = os.path.join(root_path, dataset, umap_file_name)\n",
    "embd = load_dict(umap_path).embedding_\n",
    "data[dataset] = dict(embd=embd, y=y, d=d)\n",
    "\n",
    "# Neural IPC\n",
    "dataset = \"pallium_scVI_IPC_small\"\n",
    "x, y, _, _, d = load_dataset(root_path, dataset, seed=0)\n",
    "\n",
    "# this file name is also reused for the gopca embeddings of the cell cycle datasets below\n",
    "umap_file_name = \"umap__k_15_metric_euclidean_epochs_750_seed_0_min_dist_0.1_init_pca.pkl\"\n",
    "umap_path = os.path.join(root_path, dataset, umap_file_name)\n",
    "embd = load_dict(umap_path).embedding_\n",
    "# NOTE(review): the embedding loaded above is not used; the \"UMAP\" entry of the metadata dict is stored instead\n",
    "data[dataset] = dict(embd=d[\"UMAP\"], y=y, d=d)\n",
    "\n",
    "# other datasets\n",
    "cc_datasets = [\"neurosphere\", \"hippocampus\", \"HeLa2\", \"pancreas\"]\n",
    "small_cc_datasets, full_cc_datasets = cc_datasets[:2], cc_datasets[2:]\n",
    "\n",
    "# different provided meta data makes different type of loading necessary\n",
    "# NOTE(review): in both loops the gopca UMAP (`embd`) is loaded but only the tricycle embedding is stored\n",
    "for dataset in small_cc_datasets:\n",
    "    tri_embd, y, d = load_small_cc_dataset(root_path,\n",
    "                                           dataset=dataset,\n",
    "                                           representation=\"tricycleEmbedding\",\n",
    "                                           seed=0)\n",
    "    umap_path = os.path.join(root_path, f\"{dataset}_gopca_small\", umap_file_name)\n",
    "    embd = load_dict(umap_path).embedding_\n",
    "    data[dataset] = dict(embd=tri_embd, y=y, d=d)\n",
    "\n",
    "for dataset in full_cc_datasets:\n",
    "    tri_embd, y, _, _, d = load_dataset(root_path, f\"{dataset}_tricycle\", 15, seed=0)\n",
    "    umap_path = os.path.join(root_path, f\"{dataset}_gopca\", umap_file_name)\n",
    "    embd = load_dict(umap_path).embedding_\n",
    "    data[dataset] = dict(embd=tri_embd, y=y, d=d)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "ac63b877d3976d39"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Two-row figure: scatter plots in the top row (a-f), legend and polar colour rings in the bottom row (g-l).\n",
    "mosaic = (\n",
    "    \"\"\"\n",
    "    abcdef\n",
    "    ghijkl\n",
    "    \"\"\"\n",
    ")\n",
    "\n",
    "# Panel keys in mosaic reading order, so that letters[i + n_cols] is the panel directly below letters[i].\n",
    "# (The order must match the mosaic; a permuted string silently puts datasets under the wrong titles.)\n",
    "letters = \"abcdefghijkl\"\n",
    "n_cols = 6\n",
    "\n",
    "fig, ax = plt.subplot_mosaic(\n",
    "    mosaic=mosaic,\n",
    "    height_ratios=[1.5, 1],\n",
    "    figsize=(5.5, 2),\n",
    "    # top row and the legend panel \"g\" are Cartesian; \"h\"-\"l\" are polar for the colour rings\n",
    "    per_subplot_kw={letter: {} if letter in \"abcdefg\" else {\"projection\": \"polar\"} for letter in letters},\n",
    ")\n",
    "\n",
    "# plot malaria with legend\n",
    "mca = data[\"mca_ss2\"]\n",
    "cluster_colors = mca[\"d\"][\"cluster_colors\"]\n",
    "plot_scatter(ax=ax[\"a\"],\n",
    "             x=mca[\"embd\"],\n",
    "             y=[cluster_colors[i] for i in mca[\"y\"]],\n",
    "             s=1,\n",
    "             alpha=1,\n",
    "             scalebar=False\n",
    "             )\n",
    "# dummy dots for legend: one empty Line2D handle per cluster\n",
    "dots = []\n",
    "for j in range(len(cluster_colors)):\n",
    "    dot = mlines.Line2D([], [], color=cluster_colors[j], marker='.', linestyle=\"none\",\n",
    "                        markersize=1, label=mca[\"d\"][\"cluster_print_names\"][j])\n",
    "    dots.append(dot)\n",
    "leg = ax[\"g\"].legend(handles=dots, loc=(-0.2, 0), ncol=2, frameon=False, handletextpad=0.1, columnspacing=1)\n",
    "ax[\"g\"].axis(\"off\")\n",
    "\n",
    "# exclude the legend from layout calculations\n",
    "leg.set_in_layout(False)\n",
    "\n",
    "\n",
    "# plot the datasets from Zheng et al. in panels b-e\n",
    "stage_order = [0, 5, 1, 2, 3, 4]\n",
    "for i, dataset in enumerate(cc_datasets, start=1):\n",
    "    entry = data[dataset]\n",
    "    ring_ax = ax[letters[i + n_cols]]  # polar panel directly below the scatter plot\n",
    "    plot_scatter(ax=ax[letters[i]],\n",
    "                 x=entry[\"embd\"],\n",
    "                 y=entry[\"d\"][\"colors\"][entry[\"y\"]],\n",
    "                 s=1,\n",
    "                 alpha=1,\n",
    "                 scalebar=False\n",
    "                 )\n",
    "    # make panel below invisible\n",
    "    ring_ax.xaxis.grid(False)\n",
    "    ring_ax.yaxis.grid(False)\n",
    "    ring_ax.spines[\"polar\"].set_visible(False)\n",
    "    # add cell cycle stage color ring\n",
    "    if dataset == \"pancreas\":\n",
    "        zeniths = np.arange(50, 70, 1)\n",
    "        # one 60 degree segment per stage, coloured with that stage's colour\n",
    "        for j, k in enumerate(stage_order):\n",
    "            azimuths = np.arange(0, 60, 1) + j * 60\n",
    "            values = azimuths * np.ones((20, 60))\n",
    "            ring_ax.pcolormesh(azimuths*np.pi/180.0, zeniths, values, color=entry[\"d\"][\"colors\"][k])\n",
    "        # tick setup does not depend on the segment, so it is done once after drawing the ring;\n",
    "        # ticks sit at the centre of each segment\n",
    "        xticks = 2*np.pi / 6 * (np.arange(6) + 0.5)\n",
    "        ring_ax.set_xticks(xticks)\n",
    "        ring_ax.set_xticklabels(entry[\"d\"][\"stage_names\"][stage_order])\n",
    "        ring_ax.set_yticklabels([])\n",
    "        ring_ax.tick_params(pad=-1)\n",
    "    else:\n",
    "        ring_ax.axis(\"off\")\n",
    "\n",
    "# Titles: iterating over `data` yields its keys in insertion order, which are paired with the\n",
    "# panel letters \"afbcde\"; each panel gets the dataset name plus a bold panel letter on the left.\n",
    "for letter, dataset in zip(\"afbcde\", data):\n",
    "    ax[letter].set_title(dataset_to_print[dataset])\n",
    "    ax[letter].set_title(letter, ha=\"right\", loc=\"left\", fontweight=\"bold\")\n",
    "\n",
    "# plot pallium embedding\n",
    "pallium = data[\"pallium_scVI_IPC_small\"]\n",
    "plot_scatter(ax=ax[\"f\"],\n",
    "             x=pallium[\"embd\"],\n",
    "             y=pallium[\"d\"][\"CellCycle\"],\n",
    "             cmap=\"hsv\",\n",
    "             s=1,\n",
    "             alpha=1,\n",
    "             scalebar=False\n",
    "             )\n",
    "\n",
    "# add pallium cell cycle color ring: a full circle coloured by angle with the same \"hsv\" colormap\n",
    "azimuths = np.arange(0, 361, 1)\n",
    "zeniths = np.arange(50, 70, 1)\n",
    "values = azimuths * np.ones((20, 361))\n",
    "\n",
    "ax[\"l\"].pcolormesh(azimuths*np.pi/180.0, zeniths, values, cmap=\"hsv\")\n",
    "ax[\"l\"].axis(\"off\")\n",
    "ax[\"l\"].set_title(\"Cell cycle\", va=\"top\")\n",
    "\n",
    "fig.savefig(os.path.join(fig_path, \"fig_sc_embds.pdf\"))"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "db420951d06c4ceb"
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "conda-env-ph-py",
   "language": "python",
   "display_name": "Python [conda env:ph]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
