{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%env CUDA_VISIBLE_DEVICES=1\n",
    "import matplotlib.pyplot as plt\n",
    "from utils.utils import get_path\n",
    "from utils.io_utils import load_multiple_res\n",
    "from utils.pd_utils import compute_outlier_scores, transform_dgms, filter_dgms\n",
    "from utils.fig_utils import full_dist_to_print, dist_to_color, dist_to_print\n",
    "import os\n",
    "import numpy as np"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "96d7c180050a0097"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "style_file = \"utils.style\"\n",
    "plt.style.use(style_file)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "c28039af3e46ff6d"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "root_path = get_path(\"data\")\n",
    "fig_path = os.path.join(root_path, \"figures\")"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "661bd1b571f735bb"
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Fig many methods on a toy circle"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "443292aeb2c995aa"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# hyperparameters\n",
    "\n",
    "# hyperparameters which produce the best AUC per distance type\n",
    "distances = {\n",
    "    \"euclidean\": [{}],\n",
    "    \"fermat\": [\n",
    "               {\"p\": 3},\n",
    "           ],\n",
    "    \"dtm\": [\n",
    "            {\"k\": 4, \"p_dtm\": 2, \"p_radius\": 1},\n",
    "                    ],\n",
    "    \"core\": [\n",
    "        {\"k\": 15},\n",
    "    ],\n",
    "    \"sknn_dist\": [\n",
    "        {\"k\": 100}\n",
    "    ],\n",
    "    \"tsne\": [\n",
    "         {\"perplexity\": 30},\n",
    "    ],\n",
    "    \"umap\": [\n",
    "         {\"k\": 100, \"use_rho\": True, \"include_self\": True},\n",
    "    ],\n",
    "    \"tsne_embd\": [\n",
    "        {\"perplexity\": 30, \"n_epochs\": 500, \"n_early_epochs\": 250, \"rescale_tsne\": True},\n",
    "    ],\n",
    "    \"umap_embd\": [\n",
    "        {\"k\": 15, \"n_epochs\": 750, \"min_dist\": 0.1, \"metric\": \"euclidean\"},\n",
    "    ],\n",
    "    \"eff_res\": [\n",
    "        {\"corrected\": True, \"weighted\": False, \"k\": 100, \"disconnect\": True},\n",
    "        ],\n",
    "    \"diffusion\":[\n",
    "        {\"k\": 100, \"t\": 8, \"kernel\": \"sknn\", \"include_self\": False},\n",
    "    ],\n",
    "    \"spectral\": [\n",
    "        {\"k\": 15, \"normalization\": \"none\", \"n_evecs\": 2, \"weighted\": False},\n",
    "    ]\n",
    "}\n",
    "\n",
    "n = 1000\n",
    "embd_dim = 50\n",
    "sigmas = np.linspace(0.0, 0.35, 29)\n",
    "sigmas = np.array([np.format_float_positional(sigma, precision=4, unique=True, trim='0') for sigma in sigmas]).astype(float)\n",
    "seeds = [0, 1, 2]"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "36a1ecfbc5377930"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "all_res = load_multiple_res(datasets=\"toy_circle\", distances=distances, root_path=root_path, n=n, embd_dims=embd_dim, sigmas=sigmas, seeds=seeds, n_threads=10)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "9ad0356b7bbf5419"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# filtering\n",
    "dob = 1.25\n",
    "all_res = filter_dgms(all_res, dob=dob, dim=1, binary=True)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "40fe56a63adcbe96"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# compute outlier scores\n",
    "outlier_scores = compute_outlier_scores(dgms=all_res, n_features=1, dim=1)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "3539d72e9b91f119"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# plot figure\n",
    "groups = [[\"fermat\", \"dtm\", \"core\"], [\"sknn_dist\", \"umap\", \"tsne\"], [\"umap_embd\", \"tsne_embd\"], [\"eff_res\", \"diffusion\", \"spectral\"]]\n",
    "panel_labels = [\"Density-based\", \"Graph-based\", \"Embedding-based\", \"Spectral\"]\n",
    "letters = \"abcd\"\n",
    "\n",
    "fig, ax = plt.subplots(ncols=len(groups),\n",
    "                       figsize=(5.5, 5.5/3))\n",
    "\n",
    "\n",
    "mean_eucl = outlier_scores[\"euclidean\"][\"euclidean\"].mean(1)\n",
    "std_eucl = outlier_scores[\"euclidean\"][\"euclidean\"].std(1)\n",
    "\n",
    "\n",
    "for i, group in enumerate(groups):\n",
    "    for distance in group:\n",
    "        for full_dist in outlier_scores[distance]:\n",
    "            mean = outlier_scores[distance][full_dist].mean(1)\n",
    "            std = outlier_scores[distance][full_dist].std(1)\n",
    "            ax[i].plot(sigmas,\n",
    "                       mean ,\n",
    "                       label=full_dist_to_print[full_dist].replace(\"\\n\", \" \"),\n",
    "                       c=dist_to_color[distance],\n",
    "                       clip_on=False\n",
    "            )\n",
    "            ax[i].fill_between(sigmas,\n",
    "                               mean+std,\n",
    "                               mean-std,\n",
    "                               alpha=0.2,\n",
    "                               #color=full_dist_to_color[full_dist],\n",
    "                               color=dist_to_color[distance],\n",
    "                               edgecolor=None\n",
    "                               )\n",
    "\n",
    "    if i==2:\n",
    "        ax[i].plot(sigmas, \n",
    "                   mean_eucl, \n",
    "                   c=\"k\", \n",
    "                   linestyle=\"dashed\", \n",
    "                   #label=full_dist_to_print[\"euclidean\"], \n",
    "                   label=dist_to_print[\"euclidean\"], \n",
    "                   clip_on=False)\n",
    "    else:\n",
    "                ax[i].plot(sigmas, mean_eucl, c=\"k\", linestyle=\"dashed\", clip_on=False)\n",
    "\n",
    "    ax[i].fill_between(sigmas,\n",
    "                       mean_eucl+std_eucl,\n",
    "                       mean_eucl-std_eucl,\n",
    "                       color=\"k\",\n",
    "                       alpha=0.2,\n",
    "                       edgecolor=None)\n",
    "    \n",
    "    ax[i].legend(loc='upper center',\n",
    "                 bbox_to_anchor=(0.5, -0.25),\n",
    "                 frameon=False\n",
    "          )\n",
    "\n",
    "    \n",
    "    ax[i].set_ylim(0, 1)\n",
    "    ax[i].set_xlim(0, 0.35)\n",
    "    ax[i].set_xlabel(\"Noise std $\\sigma$\")\n",
    "    if i==0:\n",
    "        ax[i].set_ylabel(\"Loop detection score\")\n",
    "        \n",
    "    if i > 0:\n",
    "        ax[i].set_yticklabels([])\n",
    "        \n",
    "    ax[i].set_title(panel_labels[i])\n",
    "    ax[i].set_title(\n",
    "        letters[i],\n",
    "        loc=\"left\",\n",
    "        ha=\"right\",\n",
    "        fontweight=\"bold\",\n",
    ")\n",
    "    \n",
    "fig.savefig(os.path.join(fig_path, \"fig_toy_circle_dob.pdf\"))"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "d4ce3a806b290f95"
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Fig eff res on circle"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "72fe527233f8cd8c"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "distances = {\n",
    "    \"euclidean\": [{}],\n",
    "    \"eff_res\": [\n",
    "        {\"corrected\": True, \"weighted\": False, \"k\": 15, \"disconnect\": True},\n",
    "        {\"corrected\": True, \"weighted\": False, \"k\": 100, \"disconnect\": True},\n",
    "        {\"corrected\": True, \"weighted\": True, \"k\": 15, \"disconnect\": True},\n",
    "        {\"corrected\": True, \"weighted\": True, \"k\": 100, \"disconnect\": True},\n",
    "        {\"corrected\": False, \"weighted\": True, \"k\": 15, \"disconnect\": True},\n",
    "        {\"corrected\": False, \"weighted\": True, \"k\": 100, \"disconnect\": True},\n",
    "\n",
    "    ],\n",
    "\n",
    "}\n",
    "\n",
    "n = 1000\n",
    "embd_dim = 50\n",
    "sigmas = np.linspace(0.0, 0.35, 29)\n",
    "sigmas = np.array([np.format_float_positional(sigma, precision=4, unique=True, trim='0') for sigma in sigmas]).astype(float)\n",
    "seeds = [0, 1, 2]"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "7eb6e976d219f508"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "all_res = load_multiple_res(datasets=\"toy_circle\", distances=distances, root_path=root_path, n=n, embd_dims=embd_dim, sigmas=sigmas, seeds=seeds, n_threads=10)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "50c486c4e208e5ab"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# take square root of birth and death times to obtain the diagram corresponding to squared eff res\n",
    "for dist in list(all_res[\"eff_res\"])[:2]:\n",
    "    all_res[\"eff_res\"][dist+\"_sqrt\"] = transform_dgms(all_res[\"eff_res\"][dist], transformation=\"sqrt\")"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "c0743c50806d237e"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# filtering\n",
    "dob = 1.25\n",
    "all_res = filter_dgms(all_res, dob=dob, dim=1, binary=True)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "79b765b9ab56c43b"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# compute detection scores\n",
    "outlier_scores = compute_outlier_scores(dgms=all_res, n_features=1, dim=1)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "d146a77124f1e50b"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# plot figure\n",
    "fig, ax = plt.subplots(figsize=(5.5, 1.5))\n",
    "\n",
    "cmap = plt.get_cmap(\"tab20\")\n",
    "\n",
    "for i, distance in enumerate(outlier_scores):\n",
    "    \n",
    "    for j, dist_str in enumerate(outlier_scores[distance]):\n",
    "        mean = outlier_scores[distance][dist_str].mean(1)\n",
    "        std = outlier_scores[distance][dist_str].std(1)\n",
    "        color = cmap(j) if distance != \"euclidean\" else \"k\"\n",
    "        linestyle = \"dashed\" if distance == \"euclidean\" else \"solid\"\n",
    "        ax.plot(sigmas,\n",
    "                mean,\n",
    "                c=color,\n",
    "                clip_on=False,\n",
    "                label=full_dist_to_print[dist_str].replace(\"\\n\", \" \"),\n",
    "                linestyle=linestyle,\n",
    "                \n",
    "                )\n",
    "        ax.fill_between(sigmas,\n",
    "                        mean+std,\n",
    "                        mean-std,\n",
    "                        alpha=0.2,\n",
    "                        color=color,\n",
    "                        edgecolor=None)\n",
    "        \n",
    "ax.legend(frameon=False, loc = (1.05, 0.05), ncols=2)\n",
    "        \n",
    "ax.set_ylim(0, 1)\n",
    "ax.set_xlim(0, 0.35)\n",
    "ax.set_xlabel(\"Noise std $\\sigma$\")\n",
    "ax.set_ylabel(\"Loop detection score\")\n",
    "\n",
    "fig.savefig(os.path.join(fig_path, \"fig_toy_circle_eff_res.pdf\"))"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "146a85fdb838c6e8"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   },
   "id": "2dd8024f607df6c5"
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "conda-env-ph-py",
   "language": "python",
   "display_name": "Python [conda env:ph]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
