{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%env CUDA_VISIBLE_DEVICES=1\n",
    "import matplotlib.pyplot as plt\n",
    "from utils.utils import get_path\n",
    "from utils.io_utils import load_multiple_res\n",
    "from utils.pd_utils import get_persistent_feature_id, filter_dgms, compute_outlier_scores\n",
    "from utils.fig_utils import full_dist_to_print, full_dist_to_color, dataset_to_print, dist_to_print, dist_to_color, plot_edges_on_scatter\n",
    "from vis_utils.plot import plot_scatter\n",
    "from vis_utils.utils import load_dict, save_dict\n",
    "from vis_utils.loaders import load_dataset\n",
    "from persim import plot_diagrams\n",
    "import os\n",
    "import numpy as np\n",
    "import copy\n",
    "import umap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# apply the project's matplotlib style sheet; the same name is later handed to persim's\n",
    "# plot_diagrams, which uses its `colormap` argument as an mpl style\n",
    "style_file = \"utils.style\"\n",
    "plt.style.use(style_file)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "8e46f34ae86503da"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# datasets, cached embeddings and PH results are read from the data root;\n",
    "# the figures are written to its `figures` subfolder\n",
    "root_path = get_path(\"data\")\n",
    "fig_path = os.path.join(root_path, \"figures\")"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "9214cce7086c8c80"
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Fig Malaria"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "14ba006f65ff33cc"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# hyperparameters and data for the malaria (mca_ss2) figure\n",
    "dataset = \"mca_ss2\"\n",
    "k = 15  # passed to load_dataset (presumably the neighbourhood size of the returned `sknn` graph)\n",
    "seeds = [0, 1, 2]\n",
    "\n",
    "x, y, sknn, pca2, d = load_dataset(root_path, dataset, k)\n",
    "print(x.shape)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "7ac72041ab9cafd"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# load the UMAP embedding of the malaria dataset (computed once and cached on disk).\n",
    "# The cache file name is derived from the same parameters that are passed to UMAP,\n",
    "# so the file name and the actual settings cannot drift apart.\n",
    "umap_params = dict(n_neighbors=10,\n",
    "                   metric=\"correlation\",\n",
    "                   min_dist=1.0,\n",
    "                   spread=2,\n",
    "                   random_state=2,\n",
    "                   init=\"spectral\",\n",
    "                   )\n",
    "umap_file_name = (f\"{dataset}_umap_{umap_params['metric']}_k_{umap_params['n_neighbors']}\"\n",
    "                  f\"_min_dist_{umap_params['min_dist']}_spread_{umap_params['spread']}\"\n",
    "                  f\"_seed_{umap_params['random_state']}_init_{umap_params['init']}.pkl\")\n",
    "umap_file = os.path.join(root_path, dataset, umap_file_name)\n",
    "\n",
    "# NOTE(review): the cache is a .pkl file (presumably pickled by save_dict); only load files you created\n",
    "try:\n",
    "    embd = load_dict(umap_file).embedding_\n",
    "except FileNotFoundError:\n",
    "    umapper = umap.UMAP(verbose=True, **umap_params)\n",
    "    umapper.fit(x)\n",
    "    save_dict(umapper, umap_file)\n",
    "    # use the fitted embedding directly instead of re-reading the file that was just written\n",
    "    embd = umapper.embedding_"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "7ff8949bd2b31f4f"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# distances shown in the malaria figure (one hyperparameter setting each)\n",
    "distances_corr = {\n",
    "    \"correlation\": [{}],\n",
    "    \"dtm\": [{\"k\": 15, \"p_dtm\": np.inf, \"p_radius\": np.inf}],\n",
    "    \"eff_res\": [{\"corrected\": True, \"weighted\": False, \"k\": 15, \"disconnect\": True}],\n",
    "    \"diffusion\": [{\"k\": 15, \"t\": 64, \"kernel\": \"sknn\", \"include_self\": False}],\n",
    "}\n",
    "\n",
    "# every non-base distance gets correlation as its input distance\n",
    "base_distances = (\"euclidean\", \"cosine\", \"correlation\")\n",
    "for distance, kwargs_list in distances_corr.items():\n",
    "    if distance not in base_distances:\n",
    "        for dist_kwargs in kwargs_list:\n",
    "            dist_kwargs.update(input_distance=\"correlation\")"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "8fa1b5aa4e56af16"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# load PH results for the malaria distances (all seeds)\n",
    "all_res_corr = load_multiple_res(\n",
    "    datasets=dataset,\n",
    "    n=None,\n",
    "    embd_dims=None,\n",
    "    sigmas=None,\n",
    "    distances=distances_corr,\n",
    "    seeds=seeds,\n",
    "    root_path=root_path,\n",
    "    n_threads=1,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "35f6a52038b0606a"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# compute detection scores without averaging (return_mean=False); n_features=2 matches the\n",
    "# two loops used for malaria in n_loops further below\n",
    "# NOTE(review): outlier_scores_corr is not used by any later cell of this notebook\n",
    "outlier_scores_corr = compute_outlier_scores(\n",
    "    all_res_corr,\n",
    "    dim=1,\n",
    "    n_features=2,\n",
    "    return_mean=False,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "36aadfef8bd82f"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# plot figure: top row = UMAP embedding with the representatives of the two most persistent loops,\n",
    "# bottom row = dim-1 persistence diagram (first seed); one column per distance\n",
    "highlight_colors = [plt.get_cmap(\"tab10\")(i) for i in [2, 5]]\n",
    "size = 6\n",
    "\n",
    "letters = \"abcdefgh\"\n",
    "titles = [\"Correlation\", \"DTM\", \"Eff. resistance\", \"Diffusion\"]\n",
    "linestyles = [\"solid\", \"dotted\"]\n",
    "\n",
    "# positions (in axes coordinates of the scatter panel) and angles for the arrows\n",
    "head_positions = [[(0.09, 0.45), (0.4, 0.6)],\n",
    "                  [(0.4, 0.75), (0.09, 0.45)],\n",
    "                  [(0.6, 0.4), (0.2, 0.75)],\n",
    "                  [(0.25, 0.7), (0.6, 0.2)],\n",
    "                  ]\n",
    "\n",
    "angles = [[(45, -60), (45, -60)],\n",
    "          [(-20, 40), (60, 130)],\n",
    "          [(0, 90), (0, 90)],\n",
    "          [(0, 90), (0, 90)],\n",
    "          ]\n",
    "\n",
    "# column index of the only panel that gets arrows from the embedding to the diagram\n",
    "arrow_panel = 3\n",
    "\n",
    "mosaic = \"\"\"\n",
    "abcd\n",
    "efgh\"\"\"\n",
    "fig, ax_dict = plt.subplot_mosaic(mosaic=mosaic, height_ratios=[1, 0.5], figsize=(5.5, 2.))\n",
    "\n",
    "ax = np.array([[ax_dict[letter] for letter in row] for row in mosaic.split(\"\\n\") if row])\n",
    "\n",
    "# the point colours only depend on the cluster labels, so compute them once for all panels\n",
    "point_colors = [d[\"cluster_colors\"][label] for label in y]\n",
    "\n",
    "# for each distance plot the scatter plot and the persistence diagram\n",
    "for i, dist in enumerate(all_res_corr):\n",
    "    full_dist = list(all_res_corr[dist].keys())[0]\n",
    "    res = all_res_corr[dist][full_dist][0]  # results of the first seed\n",
    "\n",
    "    # persistence diagram (dim 1 only, no ticks)\n",
    "    plot_diagrams(res[\"dgms\"],\n",
    "                  ax=ax[1, i],\n",
    "                  plot_only=[1],\n",
    "                  size=size,\n",
    "                  color=\"k\",\n",
    "                  colormap=style_file,  # necessary bc plot_diagrams uses the colormap as mpl style\n",
    "                  )\n",
    "    ax[1, i].legend().set_visible(False)\n",
    "    ax[1, i].set_xticks([])\n",
    "    ax[1, i].set_yticks([])\n",
    "\n",
    "    ax[0, i].set_title(titles[i])\n",
    "    ax[0, i].set_title(letters[i], loc=\"left\", ha=\"right\", fontweight=\"bold\")\n",
    "\n",
    "    # scatter plot and representatives of the 2 most persistent loops\n",
    "    idx = [get_persistent_feature_id(res, dim=1, m=m + 1) for m in range(2)]\n",
    "\n",
    "    plot_scatter(ax=ax[0, i],\n",
    "                 x=embd,\n",
    "                 y=point_colors,\n",
    "                 scalebar=False,\n",
    "                 alpha=0.5,)\n",
    "\n",
    "    for j, feature_id in enumerate(idx):\n",
    "        birth_death = res[\"dgms\"][1][feature_id]\n",
    "        # highlight the feature in the persistence diagram ...\n",
    "        ax[1, i].scatter(*birth_death.T,\n",
    "                         color=highlight_colors[j],\n",
    "                         s=size + 2\n",
    "                         )\n",
    "        # ... and draw its representative cycle on the embedding\n",
    "        plot_edges_on_scatter(ax=ax[0, i],\n",
    "                              edge_idx=res[\"cycles\"][1][feature_id],\n",
    "                              x=embd,\n",
    "                              color=highlight_colors[j],\n",
    "                              linewidth=0.5,\n",
    "                              linestyle=linestyles[j],\n",
    "                              )\n",
    "\n",
    "        # plot arrow from scatter plot to persistence diagram\n",
    "        if i == arrow_panel:\n",
    "            annot = plt.Annotation(\n",
    "                \"\",\n",
    "                xy=head_positions[i][j],\n",
    "                xycoords=ax[0, i].transAxes,\n",
    "                xytext=birth_death,\n",
    "                textcoords=ax[1, i].transData,\n",
    "                arrowprops=dict(arrowstyle=\"->\",\n",
    "                                linewidth=0.5,\n",
    "                                color=highlight_colors[j],\n",
    "                                connectionstyle=f\"angle3,angleA={angles[i][j][0]},angleB={angles[i][j][1]}\"),\n",
    "            )\n",
    "            ax[1, i].add_artist(annot)\n",
    "\n",
    "os.makedirs(fig_path, exist_ok=True)\n",
    "fig.savefig(os.path.join(fig_path, \"fig_malaria.pdf\"))"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "f1bdf95b99848f74"
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Fig all single-cell datasets"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "e1a7d7dcc8be67f4"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# hyperparameters: all settings to compare for each distance\n",
    "# (the order of distances and settings determines the order of the bars in the figures)\n",
    "distances = {\n",
    "    \"euclidean\": [{}],\n",
    "    \"correlation\": [{}],\n",
    "    \"fermat\": [{\"p\": p} for p in (1, 2, 3, 5, 7)],\n",
    "    # DTM: all combinations; p_radius varies slowest, then k, then p_dtm\n",
    "    \"dtm\": [{\"k\": k, \"p_dtm\": p_dtm, \"p_radius\": p_radius}\n",
    "            for p_radius in (1, 2, np.inf)\n",
    "            for k in (4, 15, 100)\n",
    "            for p_dtm in (2, np.inf)],\n",
    "    \"tsne_embd\": [{\"perplexity\": perplexity, \"n_epochs\": 500, \"n_early_epochs\": 250, \"rescale_tsne\": True}\n",
    "                  for perplexity in (8, 30, 333)],\n",
    "    \"umap_embd\": [{\"k\": k, \"n_epochs\": 750, \"min_dist\": 0.1} for k in (100, 999)],\n",
    "    \"eff_res\": [{\"corrected\": True, \"weighted\": False, \"k\": k, \"disconnect\": True} for k in (15, 100)],\n",
    "    \"diffusion\": [{\"k\": k, \"t\": t, \"kernel\": \"sknn\", \"include_self\": False}\n",
    "                  for t in (8, 64)\n",
    "                  for k in (15, 100)],\n",
    "    \"spectral\": [{\"k\": 15, \"normalization\": \"none\", \"n_evecs\": n_evecs, \"weighted\": False}\n",
    "                 for n_evecs in (2, 5, 10)],\n",
    "}\n",
    "\n",
    "datasets = [\"mca_ss2\", \"neurosphere_gopca_small\", \"hippocampus_gopca_small\", \"pallium_scVI_IPC_small\", \"HeLa2_gopca\", \"pancreas_gopca\"]\n",
    "\n",
    "seeds = [0, 1, 2]\n",
    "\n",
    "# number of loops to score per dataset: two for malaria, one for all others\n",
    "n_loops = {dataset: 2 if dataset == \"mca_ss2\" else 1 for dataset in datasets}"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "f315006a23b905c4"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# malaria is analysed with correlation as input distance: copy the settings and tag every\n",
    "# non-base distance with it (dropping an explicit `metric`, if present)\n",
    "distances_corr = copy.deepcopy(distances)\n",
    "base_distances = (\"euclidean\", \"cosine\", \"correlation\")\n",
    "for distance, kwargs_list in distances_corr.items():\n",
    "    if distance not in base_distances:\n",
    "        for dist_kwargs in kwargs_list:\n",
    "            dist_kwargs[\"input_distance\"] = \"correlation\"\n",
    "            dist_kwargs.pop(\"metric\", None)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "fbab9f9333c15c12"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# load PH results (malaria uses the settings with correlation as input distance)\n",
    "load_kwargs = dict(n=None, embd_dims=None, sigmas=None, seeds=seeds, root_path=root_path, n_threads=1)\n",
    "all_res = {}\n",
    "for dataset in datasets:\n",
    "    dists = distances_corr if dataset == \"mca_ss2\" else distances\n",
    "    all_res[dataset] = load_multiple_res(datasets=dataset, distances=dists, **load_kwargs)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "4668b94f1fb50ec6"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# filtering of the dim-1 diagrams (binary=True); dob is the threshold handed to filter_dgms\n",
    "# NOTE(review): rebinds all_res, so re-running this cell filters already-filtered results\n",
    "dob = 1.25\n",
    "all_res = filter_dgms(all_res, dim=1, dob=dob, binary=True)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "e747c017013466eb"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# compute detection scores, scoring as many loops per dataset as given in n_loops\n",
    "outlier_scores = {}\n",
    "for dataset in datasets:\n",
    "    n_features = n_loops[dataset]\n",
    "    outlier_scores[dataset] = compute_outlier_scores(all_res[dataset], dim=1, n_features=n_features)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "7c4e363a1d7045be"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# select the best hyperparameter setting (highest mean detection score) for each distance on each dataset.\n",
    "# Ties go to the setting listed last; settings whose mean is NaN are never selected.\n",
    "best_method = {}\n",
    "for dataset in datasets:\n",
    "    best_full_dist_on_dataset = {}\n",
    "    for dist in distances:\n",
    "        # start below any possible score so every distance with a finite score gets an entry\n",
    "        best_mean = -np.inf\n",
    "        for full_dist in outlier_scores[dataset][dist]:\n",
    "            # keys are kept as-is (for malaria they presumably carry the input-distance suffix),\n",
    "            # since they are used to index outlier_scores again when plotting\n",
    "            mean = outlier_scores[dataset][dist][full_dist].mean()\n",
    "            if best_mean <= mean:\n",
    "                best_mean = mean\n",
    "                best_full_dist_on_dataset[dist] = full_dist\n",
    "    best_method[dataset] = best_full_dist_on_dataset"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "55448d8fa43d21ba"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# results of manual evaluation of the correctness of the most persistent detected loop\n",
    "# True = the detected loop was judged correct; False entries are drawn as hatched bars in the figure.\n",
    "# Malaria (mca_ss2) lists \"correlation\" instead of \"euclidean\" because it is shown with correlation\n",
    "# distance, while the other datasets are shown with euclidean distance.\n",
    "correct_detection = {\n",
    "    \"mca_ss2\": {\n",
    "        \"correlation\": False,\n",
    "        \"fermat\": False,\n",
    "        \"dtm\": False,\n",
    "        \"tsne_embd\": False,\n",
    "        \"umap_embd\": False,\n",
    "        \"eff_res\": True,\n",
    "        \"diffusion\": True,\n",
    "        \"spectral\": True,\n",
    "    },\n",
    "    \"neurosphere_gopca_small\": {\n",
    "        \"euclidean\": True,\n",
    "        \"fermat\": True,\n",
    "        \"dtm\": True,\n",
    "        \"tsne_embd\": True,\n",
    "        \"umap_embd\": True,\n",
    "        \"eff_res\": True,\n",
    "        \"diffusion\": True,\n",
    "        \"spectral\": True,\n",
    "    },\n",
    "    \"hippocampus_gopca_small\": {\n",
    "        \"euclidean\": True,\n",
    "        \"fermat\": False,\n",
    "        \"dtm\": True,\n",
    "        \"tsne_embd\": True,\n",
    "        \"umap_embd\": True,\n",
    "        \"eff_res\": True,\n",
    "        \"diffusion\": True,\n",
    "        \"spectral\": True,\n",
    "    },\n",
    "    \"pallium_scVI_IPC_small\": {\n",
    "        \"euclidean\": True,\n",
    "        \"fermat\": True,\n",
    "        \"dtm\": True,\n",
    "        \"tsne_embd\": True,\n",
    "        \"umap_embd\": False,\n",
    "        \"eff_res\": True,\n",
    "        \"diffusion\": True,\n",
    "        \"spectral\": True,\n",
    "    },\n",
    "    \"HeLa2_gopca\": {\n",
    "        \"euclidean\": False,\n",
    "        \"fermat\": False,\n",
    "        \"dtm\": True,\n",
    "        \"tsne_embd\": True,\n",
    "        \"umap_embd\": True,\n",
    "        \"eff_res\": True,\n",
    "        \"diffusion\": True,\n",
    "        \"spectral\": True,\n",
    "    },\n",
    "    \"pancreas_gopca\": {\n",
    "        \"euclidean\": False,\n",
    "        \"fermat\": False,\n",
    "        \"dtm\": False,\n",
    "        \"tsne_embd\": True,\n",
    "        \"umap_embd\": True,\n",
    "        \"eff_res\": True,\n",
    "        \"diffusion\": False,\n",
    "        \"spectral\": False,\n",
    "    },\n",
    "}"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "e7be5cdfdf6a4bfa"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# inspect the selected hyperparameter settings of one dataset (--> Table S2);\n",
    "# change the name to inspect another dataset\n",
    "inspected_dataset = \"pancreas_gopca\"\n",
    "best_method[inspected_dataset]"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "e34c78874ea0afa5"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# plot figure: best hyperparameter setting per distance, one panel per dataset.\n",
    "# Bars are hatched when the detected loop was judged incorrect (see correct_detection).\n",
    "letters = \"abcdefgh\"\n",
    "plt.rcParams.update({'hatch.color': 'w'})\n",
    "fig, ax = plt.subplots(ncols=len(datasets), nrows=1, figsize=(5.5, 1.25))\n",
    "\n",
    "for i, dataset in enumerate(datasets):\n",
    "    cax = ax[i]\n",
    "    is_malaria = dataset == \"mca_ss2\"\n",
    "    # for malaria we plot correlation, for the others euclidean distance\n",
    "    skipped_dist = \"euclidean\" if is_malaria else \"correlation\"\n",
    "    plotted_dists = [dist for dist in best_method[dataset] if dist != skipped_dist]\n",
    "\n",
    "    for j, dist in enumerate(plotted_dists):\n",
    "        full_dist = best_method[dataset][dist]\n",
    "        scores = outlier_scores[dataset][dist][full_dist]\n",
    "\n",
    "        # malaria's correlation bar uses the euclidean colour, so the first bar matches across panels\n",
    "        dist_for_color = \"euclidean\" if (is_malaria and dist == \"correlation\") else dist\n",
    "\n",
    "        bar_kwargs = dict(width=0.8, yerr=scores.std(), label=dist, color=dist_to_color[dist_for_color])\n",
    "        if not correct_detection[dataset][dist]:\n",
    "            # hatch the bar if the detection was deemed incorrect\n",
    "            bar_kwargs.update(hatch=\"////\", alpha=1)\n",
    "        cax.bar(j, scores.mean(), **bar_kwargs)\n",
    "\n",
    "    # prettify panel (once per panel, not once per bar)\n",
    "    cax.set_ylim(0.0, 1.0)\n",
    "    cax.set_yticks([0, 0.25, 0.5, 0.75, 1.0])\n",
    "    if i == 0:\n",
    "        cax.set_ylabel(\"Detection score\")\n",
    "        cax.set_yticklabels([0, 0.25, 0.5, 0.75, 1.0])\n",
    "    else:\n",
    "        cax.set_yticklabels([])\n",
    "\n",
    "    cax.set_title(dataset_to_print[dataset])\n",
    "    cax.set_title(letters[i], loc=\"left\", fontweight=\"bold\", ha=\"right\")\n",
    "\n",
    "    cax.spines['left'].set_position(('outward', 5))\n",
    "    cax.set_xlim(-0.5, len(plotted_dists) - 0.5)\n",
    "    cax.set_xticks(range(len(plotted_dists)))\n",
    "\n",
    "    # x labels come from this panel's own distances: correlation for malaria, euclidean for the others\n",
    "    first_label = \"Correlation\" if is_malaria else \"Euclidean\"\n",
    "    cax.set_xticklabels([first_label] + [dist_to_print[name] for name in plotted_dists[1:]], rotation=90)\n",
    "\n",
    "os.makedirs(fig_path, exist_ok=True)\n",
    "fig.savefig(os.path.join(fig_path, \"fig_scRNAseq_sep_filtered.pdf\"))"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "c07d978c8d52ab6"
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Fig with all methods for all datasets"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "d4713cbcdaf73887"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# hyperparameters: all settings of all methods (order determines the order of the bars)\n",
    "distances = {\n",
    "    \"euclidean\": [{}],\n",
    "    \"correlation\": [{}],\n",
    "    \"fermat\": [{\"p\": p} for p in (2, 3, 5, 7)],\n",
    "    \"core\": [{\"k\": k} for k in (15, 100)],\n",
    "    \"sknn_dist\": [{\"k\": k} for k in (15, 100)],\n",
    "\n",
    "    \"tsne\": [{\"perplexity\": perplexity} for perplexity in (30, 200, 333)],\n",
    "    \"umap\": [{\"k\": k, \"use_rho\": True, \"include_self\": True} for k in (100, 999)],\n",
    "    \"tsne_embd\": [{\"perplexity\": perplexity, \"n_epochs\": 500, \"n_early_epochs\": 250, \"rescale_tsne\": True}\n",
    "                  for perplexity in (8, 30, 333)],\n",
    "    \"umap_embd\": [{\"k\": k, \"n_epochs\": 750, \"min_dist\": 0.1} for k in (15, 100, 999)],\n",
    "    \"eff_res\": [{\"corrected\": True, \"weighted\": False, \"k\": k, \"disconnect\": True} for k in (15, 100)],\n",
    "    \"diffusion\": [{\"k\": k, \"t\": t, \"kernel\": \"sknn\", \"include_self\": False}\n",
    "                  for t in (8, 64)\n",
    "                  for k in (15, 100)],\n",
    "    \"spectral\": [{\"k\": 15, \"normalization\": \"none\", \"n_evecs\": n_evecs, \"weighted\": False}\n",
    "                 for n_evecs in (2, 5, 10)],\n",
    "}\n",
    "\n",
    "datasets = [\"mca_ss2\", \"neurosphere_gopca_small\", \"hippocampus_gopca_small\", \"pallium_scVI_IPC_small\", \"HeLa2_gopca\", \"pancreas_gopca\"]\n",
    "\n",
    "seeds = [0, 1, 2]\n",
    "\n",
    "# number of loops to score per dataset: two for malaria, one for all others\n",
    "n_loops = {dataset: 2 if dataset == \"mca_ss2\" else 1 for dataset in datasets}"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "a8882af3715e374f"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# malaria is analysed with correlation as input distance: copy the settings and tag every\n",
    "# non-base distance with it (dropping an explicit `metric`, if present)\n",
    "distances_corr = copy.deepcopy(distances)\n",
    "base_distances = (\"euclidean\", \"cosine\", \"correlation\")\n",
    "for distance, kwargs_list in distances_corr.items():\n",
    "    if distance not in base_distances:\n",
    "        for dist_kwargs in kwargs_list:\n",
    "            dist_kwargs[\"input_distance\"] = \"correlation\"\n",
    "            dist_kwargs.pop(\"metric\", None)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "4990a82230b71303"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# load the PH results (malaria uses the settings with correlation as input distance)\n",
    "load_kwargs = dict(n=None, embd_dims=None, sigmas=None, seeds=seeds, root_path=root_path, n_threads=1)\n",
    "all_res = {}\n",
    "for dataset in datasets:\n",
    "    dists = distances_corr if dataset == \"mca_ss2\" else distances\n",
    "    all_res[dataset] = load_multiple_res(datasets=dataset, distances=dists, **load_kwargs)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "ed5821486c15eaee"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# filtering of the dim-1 diagrams (binary=True); dob is the threshold handed to filter_dgms\n",
    "# NOTE(review): rebinds all_res, so re-running this cell filters already-filtered results\n",
    "dob = 1.25\n",
    "all_res = filter_dgms(all_res, dim=1, dob=dob, binary=True)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "77e3b3f1362c3392"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# compute the detection scores, scoring as many loops per dataset as given in n_loops\n",
    "outlier_scores = {}\n",
    "for dataset in datasets:\n",
    "    n_features = n_loops[dataset]\n",
    "    outlier_scores[dataset] = compute_outlier_scores(all_res[dataset], dim=1, n_features=n_features)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "beee14e2cabd72ef"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# plot figure: every hyperparameter setting of every distance, one row per dataset\n",
    "letters = \"abcdefgh\"\n",
    "plt.rcParams.update({'hatch.color': 'w'})\n",
    "fig, ax = plt.subplots(ncols=1, nrows=len(datasets), figsize=(5.5, 7))\n",
    "\n",
    "bar_width = 0.2\n",
    "sep = 1.5 * bar_width  # extra gap between the blocks of different distances\n",
    "yticks = [0, 0.25, 0.5, 0.75, 1.0]\n",
    "\n",
    "for i, dataset in enumerate(datasets):\n",
    "    # one row per dataset\n",
    "    cax = ax[i]\n",
    "    is_malaria = dataset == \"mca_ss2\"\n",
    "    # show correlation distance for malaria and euclidean distance for the other datasets\n",
    "    skipped_dist = \"euclidean\" if is_malaria else \"correlation\"\n",
    "    x_vals = [0]\n",
    "    full_dists = []\n",
    "    for dist in outlier_scores[dataset]:\n",
    "        if dist == skipped_dist:\n",
    "            continue\n",
    "        x_val = x_vals[-1] + sep\n",
    "\n",
    "        # plot all hyperparameter results for this distance in one block\n",
    "        for full_dist in outlier_scores[dataset][dist]:\n",
    "            scores = outlier_scores[dataset][dist][full_dist]\n",
    "            x_val += bar_width\n",
    "            x_vals.append(x_val)\n",
    "\n",
    "            # key into full_dist_to_color / full_dist_to_print\n",
    "            full_dist_for_color = full_dist.removesuffix(\"_input_distance_correlation\")\n",
    "            if dist == \"umap_embd\":\n",
    "                full_dist_for_color += \"_metric_euclidean\"\n",
    "            if is_malaria and dist == \"correlation\":\n",
    "                # malaria's correlation bar takes the style of the euclidean bar of the other rows\n",
    "                full_dist_for_color = \"euclidean\"\n",
    "            full_dists.append(full_dist_for_color)\n",
    "\n",
    "            cax.bar(x_val,\n",
    "                    scores.mean(),\n",
    "                    bar_width,\n",
    "                    yerr=scores.std(),\n",
    "                    label=full_dist_to_print[full_dist_for_color],\n",
    "                    color=full_dist_to_color[full_dist_for_color],\n",
    "                    )\n",
    "\n",
    "    # prettify the panel\n",
    "    cax.set_ylim(0.0, 1.0)\n",
    "    cax.set_xticks([])\n",
    "    cax.set_xticklabels([])\n",
    "\n",
    "    cax.set_yticks(yticks)\n",
    "    cax.set_ylabel(\"Detection score\")\n",
    "    cax.set_yticklabels(yticks)\n",
    "\n",
    "    cax.set_title(dataset_to_print[dataset], loc=\"left\")\n",
    "    cax.text(-0.2, 1.05, letters[i], fontweight=\"bold\")  # data coordinates; ylim is fixed to [0, 1]\n",
    "\n",
    "    cax.spines['left'].set_position(('outward', 5))\n",
    "\n",
    "    # one tick per bar; tick labels only on the bottom row\n",
    "    cax.set_xticks(x_vals[1:])\n",
    "    if i == len(datasets) - 1:\n",
    "        cax.set_xticklabels([\"Euclidean / Correlation\"] + [full_dist_to_print[name].replace(\"\\n\", \" \") for name in full_dists[1:]],\n",
    "                            rotation=90)\n",
    "\n",
    "os.makedirs(fig_path, exist_ok=True)\n",
    "fig.savefig(os.path.join(fig_path, \"fig_scRNAseq_sep_filtered_all_methods.pdf\"))"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "965cf8644264ab1d"
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "conda-env-ph-py",
   "language": "python",
   "display_name": "Python [conda env:ph]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
