{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%env CUDA_VISIBLE_DEVICES=1\n",
    "import matplotlib.pyplot as plt\n",
    "from utils.utils import get_path\n",
    "from utils.io_utils import load_multiple_res, dist_kwargs_to_str\n",
    "from utils.pd_utils import compute_outlier_scores, filter_dgms\n",
    "from utils.fig_utils import full_dist_to_print, dist_to_color\n",
    "import os\n",
    "import numpy as np"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "21d2b03e56ae5b06"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "style_file = \"utils.style\"\n",
    "plt.style.use(style_file)\n",
    "root_path = get_path(\"data\")\n",
    "fig_path = os.path.join(root_path, \"figures\")"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "58236ecb76fcd291"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# hyperparameters\n",
    "embd_dims = [2, 10, 20, 30, 40, 50]\n",
    "sigmas = np.linspace(0.0, 0.35, 29)\n",
    "sigmas = np.array([np.format_float_positional(sigma, precision=4, unique=True, trim='0') for sigma in sigmas]).astype(float)\n",
    "seeds = [0, 1, 2]\n",
    "n = 1000\n",
    "\n",
    "dataset = \"toy_circle\"\n",
    "\n",
    "distances = {\"euclidean\": [{}],\n",
    "             \"dtm\": [\n",
    "                 {\"k\": 4, \"p_dtm\": 2, \"p_radius\": 1},\n",
    "                    ],\n",
    "                 \"eff_res\": [\n",
    "                     {\"corrected\": True, \"weighted\": False, \"k\": 100, \"disconnect\": True},\n",
    "                 ],\n",
    "             \"diffusion\": [\n",
    "                 {\"k\": 100, \"t\": 8, \"kernel\": \"sknn\", \"include_self\": False},\n",
    "                ]\n",
    "             }\n",
    "\n",
    "dist_2_full_dist = {dist: dist + dist_kwargs_to_str(distances[dist][0]) for dist in distances}"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "e223bf0f26d58d2f"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# load all PH results\n",
    "all_res = load_multiple_res(datasets=dataset, \n",
    "                            distances=distances,\n",
    "                            root_path=root_path,\n",
    "                            n=n,\n",
    "                            seeds=seeds,\n",
    "                            sigmas=sigmas,\n",
    "                            embd_dims=embd_dims, \n",
    "                            n_threads=10)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "9228d694d43e1b25"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# filtering\n",
    "dob = 1.25\n",
    "all_res = filter_dgms(all_res, dim=1, dob=dob, binary=True)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "8e7b39366c5e5e79"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# comptue detection scores\n",
    "outlier_scores = compute_outlier_scores(all_res, n_features=1, dim=1)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "f59f22b68e73daff"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def highlight_cell(x,y, width, ax=None, **kwargs):\n",
    "    rect = plt.Rectangle((x-.5, y-.5), width,1, fill=False, **kwargs)\n",
    "    ax = ax or plt.gca()\n",
    "    ax.add_patch(rect)\n",
    "    return rect"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "d74de0f6a04c2fa5"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# plot figure\n",
    "letters = \"abcde\"\n",
    "\n",
    "sigma = 0.25\n",
    "\n",
    "titles = [\"Euclidean\", \"DTM\", \"Eff. resist.\", \"Diffusion\"]\n",
    "\n",
    "fig, ax = plt.subplots(ncols=len(titles)+1, figsize=(5.5, 1.5))\n",
    "\n",
    "for i, dist in enumerate(distances):\n",
    "    i+=1\n",
    "    cax = ax[i]\n",
    "    \n",
    "    # plot heat maps for each distance\n",
    "    ssigs, ddims = np.meshgrid(sigmas, embd_dims)\n",
    "    \n",
    "    means = np.stack([outlier_scores[embd_dim][dist].mean(2)[0] for embd_dim in embd_dims])\n",
    "    stds = np.stack([outlier_scores[embd_dim][dist].std(2)[0] for embd_dim in embd_dims])\n",
    "    im = ax[i].imshow(means.T, cmap=\"coolwarm\", aspect=1/3, vmin=0, vmax=1.0, interpolation=\"none\")\n",
    "    \n",
    "    \n",
    "    highlight_cell(x=0, y=np.where(sigmas == sigma)[0][0], width=len(embd_dims), ax=ax[i], color=\"black\", linewidth=0.5, linestyle=\"dashed\")\n",
    "    \n",
    "    \n",
    "    ax[i].set_xticks(np.arange(len(embd_dims)))\n",
    "    ax[i].set_xticklabels(embd_dims)\n",
    "    ax[i].set_yticks(np.arange(len(sigmas))[::4])\n",
    "    \n",
    "    if i==1:\n",
    "        ax[i].set_yticklabels([0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35])\n",
    "        ax[i].set_ylabel(\"Noise std $\\sigma$\")\n",
    "\n",
    "    else:\n",
    "        ax[i].set_yticklabels([])\n",
    "    \n",
    "    ax[i].spines[[\"top\", \"right\"]].set_visible(True)\n",
    "    \n",
    "    ax[i].set_xlabel(\"Ambient dim.\")\n",
    "\n",
    "    ax[i].set_title(titles[i-1])\n",
    "    ax[i].set_title(\n",
    "        letters[i],\n",
    "        loc=\"left\",\n",
    "        ha=\"right\",\n",
    "        fontweight=\"bold\",\n",
    ")\n",
    "    \n",
    "    # add the detection score for the selected sigma to the first panel\n",
    "    idx = np.where(sigmas == sigma)[0][0]\n",
    "    \n",
    "    if dist == \"euclidean\":\n",
    "        ax[0].plot(embd_dims, \n",
    "                 means[:, idx],\n",
    "                 color=dist_to_color[dist],\n",
    "                 label=full_dist_to_print[dist_2_full_dist[dist]],\n",
    "                 linestyle=\"dashed\"\n",
    "                 )\n",
    "    elif dist == \"eff_res\":\n",
    "        ax[0].plot(embd_dims, \n",
    "                 means[:, idx],\n",
    "                 color=dist_to_color[dist],\n",
    "                 label=f\"Eff. resist. $k={distances['eff_res'][0]['k']}$\"  \n",
    "                 )\n",
    "    else:\n",
    "        ax[0].plot(embd_dims, \n",
    "             means[:, idx],\n",
    "             color=dist_to_color[dist],\n",
    "             label=full_dist_to_print[dist_2_full_dist[dist]],\n",
    "             )\n",
    "    ax[0].fill_between(embd_dims, \n",
    "             means[:, idx] - stds[:, idx],\n",
    "             means[:, idx] + stds[:, idx],\n",
    "             color=dist_to_color[dist],\n",
    "             alpha=0.2,\n",
    "             edgecolor=None\n",
    "             )\n",
    "    ax[0].legend(loc='upper center',\n",
    "                     bbox_to_anchor=(0.5, -0.35),\n",
    "                     frameon=False\n",
    "              )\n",
    "\n",
    "# prettify the first panel\n",
    "ax[0].set_ylim(0, 1)\n",
    "ax[0].set_xlim(2, 50)\n",
    "ax[0].set_xticks(embd_dims)\n",
    "ax[0].set_xticklabels(embd_dims)\n",
    "ax[0].set_xlabel(\"Ambient dimension\")\n",
    "ax[0].set_ylabel(\"Loop detection score\")\n",
    "ax[0].set_title(\" $\\sigma = $0.25 \", va=\"top\")\n",
    "ax[0].set_title(\n",
    "    letters[0],\n",
    "    loc=\"left\",\n",
    "    ha=\"right\",\n",
    "    va=\"top\",\n",
    "    fontweight=\"bold\",\n",
    ")\n",
    "ax[0].legend(loc=(0.5, 0.25), frameon=False)\n",
    "ax[0].spines['left'].set_position(('outward', 5))\n",
    "\n",
    "\n",
    "fig.colorbar(im, ax=ax[1:5], label=\"Loop detection score\")\n",
    "\n",
    "fig.savefig(os.path.join(fig_path, \"fig_dims_dob.pdf\"))"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "918355bb6f1e9fde"
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "conda-env-ph-py",
   "language": "python",
   "display_name": "Python [conda env:ph]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
