{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%env CUDA_VISIBLE_DEVICES=1\n",
    "import matplotlib.pyplot as plt\n",
    "from utils.utils import get_path\n",
    "from utils.pd_utils import get_persistent_feature_id\n",
    "from utils.io_utils import load_multiple_res, dist_kwargs_to_str\n",
    "from utils.dist_utils import get_dist\n",
    "from utils.toydata_utils import get_toy_data\n",
    "from vis_utils.plot import plot_scatter\n",
    "import os\n",
    "import numpy as np\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import MDS"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "a22db7cc1cfe38e0"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Plot style and output locations used throughout the notebook.\n",
    "style_file = \"utils.style\"\n",
    "plt.style.use(style_file)\n",
    "root_path = get_path(\"data\")\n",
    "fig_path = os.path.join(root_path, \"figures\")\n",
    "# Create the figure directory up front so that fig.savefig does not fail when it is missing.\n",
    "os.makedirs(fig_path, exist_ok=True)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "fc700f21e136e201"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Visualization of distances and PDs with varying ambient dimension\n"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "dd5928aadeb71bc9"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Parameters of the toy_circle data set\n",
    "sigma, seed = 0.25, 2\n",
    "n, embd_dim = 1000, 50\n",
    "\n",
    "x = get_toy_data(dataset=\"toy_circle\", d=embd_dim, n=n, sigma=sigma, seed=seed,\n",
    "                 gaussian={\"sigma\": sigma})\n",
    "x_pca = PCA(n_components=2).fit_transform(x)\n",
    "\n",
    "# Distances on the data; the two non-Euclidean ones share the same `k` argument\n",
    "knn = 100\n",
    "d_eucl = get_dist(x, distance=\"euclidean\")\n",
    "d_eff = get_dist(x, distance=\"eff_res\", k=knn, weighted=True, disconnect=True, corrected=True)\n",
    "d_diff = get_dist(x, distance=\"diffusion\", k=knn, t=8, kernel=\"sknn\", include_self=False)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "f80657a92d495705"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Settings for the data set that is embedded with MDS in the next cell\n",
    "pt_id = 2\n",
    "sigma, seed = 0.25, 0\n",
    "n, embd_dim = 1000, 50\n"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "84f009cc3c29a3c1"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Sample the toy_circle data set with the parameters set above.\n",
    "x = get_toy_data(dataset=\"toy_circle\", d=embd_dim, n=n, sigma=sigma, seed=seed, gaussian={\"sigma\": sigma})\n",
    "\n",
    "# Distance matrices that are passed to MDS as precomputed dissimilarities.\n",
    "d_eucl = get_dist(x, distance=\"euclidean\")\n",
    "d_eff = get_dist(x, distance=\"eff_res\", k=100, weighted=True, disconnect=True, corrected=True)\n",
    "d_diff = get_dist(x, distance=\"diffusion\", k=100, t=8, kernel=\"sknn\", include_self=False)\n",
    "\n",
    "# Fix random_state: MDS starts from random initial configurations, so without it\n",
    "# the embeddings (and hence the saved figure) would differ between runs.\n",
    "mds = MDS(2, eps=1e-6, max_iter=6000, dissimilarity=\"precomputed\", random_state=seed)\n",
    "\n",
    "mds_eucl = mds.fit_transform(d_eucl)\n",
    "mds_eff = mds.fit_transform(d_eff)\n",
    "mds_diff = mds.fit_transform(d_diff)\n"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "20b3a3e64b03ec0d"
  },
  {
   "cell_type": "markdown",
   "source": [],
   "metadata": {
    "collapsed": false
   },
   "id": "f7feddd52c15ecf6"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Hyperparameters of the precomputed persistent-homology runs\n",
    "embd_dims = [2, 10, 20, 30, 40, 50]\n",
    "seeds = [0, 1, 2]\n",
    "n = 1000\n",
    "dataset = \"toy_circle\"\n",
    "\n",
    "# Noise levels: format every grid value with at most 4 decimals and parse it back to float\n",
    "sigma_grid = np.linspace(0.0, 0.35, 29)\n",
    "sigma_strs = [np.format_float_positional(s, precision=4, unique=True, trim='0') for s in sigma_grid]\n",
    "sigmas = np.array(sigma_strs).astype(float)\n",
    "\n",
    "# Each distance name maps to a list of keyword-argument dicts\n",
    "distances = {\n",
    "    \"euclidean\": [{}],\n",
    "}\n",
    "\n",
    "# Distance name -> name extended by the string form of its first keyword-argument dict\n",
    "dist_2_full_dist = {name: name + dist_kwargs_to_str(cfgs[0]) for name, cfgs in distances.items()}"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "237eceeb5b034e05"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Load the PH results for every combination of the hyperparameters set above\n",
    "load_kwargs = dict(\n",
    "    datasets=dataset,\n",
    "    distances=distances,\n",
    "    root_path=root_path,\n",
    "    n=n,\n",
    "    seeds=seeds,\n",
    "    sigmas=sigmas,\n",
    "    embd_dims=embd_dims,\n",
    "    n_threads=10,\n",
    ")\n",
    "all_res = load_multiple_res(**load_kwargs)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "d2a3898c710f2b1"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# plot figure: persistence diagrams (panels a-c) next to MDS embeddings (panels d-f)\n",
    "\n",
    "selected_sigmas = [0.1, 0.2, 0.25]\n",
    "selected_dims = [2, 20, 50]\n",
    "seed = 0\n",
    "\n",
    "# One persistence-diagram panel per ambient dimension, followed by three MDS panels.\n",
    "n_pd = len(selected_dims)\n",
    "n_mds = 3\n",
    "legend_panel = n_pd - 1  # last diagram panel hosts the legend\n",
    "\n",
    "fig, ax = plt.subplots(ncols=n_pd + n_mds, width_ratios=(1,) * (n_pd + n_mds), figsize=(5.5, 1.15))\n",
    "\n",
    "cmap = plt.get_cmap(\"tab10\")\n",
    "size = 7\n",
    "lims = []\n",
    "\n",
    "letters = \"abcdef\"\n",
    "\n",
    "# plot persistence diagrams for different dimensions and noise levels\n",
    "for i, d in enumerate(selected_dims):\n",
    "    # iterate from the largest to the smallest noise level; j still indexes the colour\n",
    "    for j, sigma in sorted(enumerate(selected_sigmas), reverse=True):\n",
    "        res = all_res[d][\"euclidean\"][\"euclidean\"][sigma][seed]\n",
    "        dgm = res[\"dgms\"][1]\n",
    "\n",
    "        if d == 50 and sigma == 0.25:\n",
    "            ind = get_persistent_feature_id(res, dim=1, m=2)  # correct loop is not most persistent\n",
    "        else:\n",
    "            ind = get_persistent_feature_id(res, dim=1)\n",
    "\n",
    "        mask = np.ones(len(dgm), dtype=bool)\n",
    "        mask[ind] = False\n",
    "\n",
    "        # only the panel that carries the legend needs labelled points\n",
    "        # (raw f-string: \"\\s\" is not a valid escape sequence in a normal string)\n",
    "        label_kwargs = {\"label\": rf\"$\\sigma = {sigma}$\"} if i == legend_panel else {}\n",
    "        ax[i].scatter(*dgm[mask].T,\n",
    "                      color=cmap(j),\n",
    "                      s=size,\n",
    "                      alpha=1,\n",
    "                      edgecolor=\"none\",\n",
    "                      **label_kwargs)\n",
    "        # mark feature of correct loop\n",
    "        ax[i].scatter(*dgm[ind].T,\n",
    "                      color=cmap(j),\n",
    "                      s=size + 1,\n",
    "                      marker=\"X\",\n",
    "                      edgecolor=\"k\",\n",
    "                      linewidth=0.5,\n",
    "                      alpha=1)\n",
    "\n",
    "    # prettify the plots; remember the upper axis limit before the ticks are changed\n",
    "    lims.append(max(ax[i].get_xlim()[1], ax[i].get_ylim()[1]))\n",
    "\n",
    "    ax[i].set_xlabel(\"Birth\")\n",
    "    ax[i].set_xticks([0, 1, 2])\n",
    "    if i == 0:\n",
    "        ax[i].set_ylabel(\"Death\")\n",
    "    else:\n",
    "        ax[i].set_yticklabels([])\n",
    "    ax[i].set_aspect(\"equal\", \"box\")\n",
    "\n",
    "    ax[i].set_title(f\"Euclidean\\n$d = {d}$\", va=\"top\")\n",
    "    ax[i].set_title(\n",
    "        letters[i] + \"\\n\",\n",
    "        loc=\"left\",\n",
    "        ha=\"right\",\n",
    "        va=\"top\",\n",
    "        fontweight=\"bold\",\n",
    "    )\n",
    "\n",
    "ax[legend_panel].legend(loc=(0.6, 0.02),\n",
    "                        frameon=False,\n",
    "                        ncols=1,\n",
    "                        handletextpad=0.1,\n",
    "                        markerscale=2.0)\n",
    "\n",
    "# give all persistence diagrams the same limits and draw the diagonal\n",
    "lim = np.max(lims)\n",
    "for i in range(n_pd):\n",
    "    ax[i].plot([0, lim], [0, lim], \"--\", c=\"k\", linewidth=1)\n",
    "    ax[i].set_xlim(0, lim)\n",
    "    ax[i].set_ylim(0, lim)\n",
    "\n",
    "\n",
    "# plot MDS embeddings\n",
    "\n",
    "pt_id = 3  # id for reference point\n",
    "\n",
    "\n",
    "def rotation_to_top(v):\n",
    "    \"\"\"Return the 2x2 rotation R for which the row vector ``v @ R`` equals ``(0, |v|)``.\n",
    "\n",
    "    arctan2 keeps the sign of the angle, so the reference point ends up at the top\n",
    "    no matter on which side of the y-axis the embedding placed it.\n",
    "    \"\"\"\n",
    "    theta = np.arctan2(-v[0], v[1])\n",
    "    return np.array([[np.cos(theta), -np.sin(theta)],\n",
    "                     [np.sin(theta), np.cos(theta)]])\n",
    "\n",
    "\n",
    "# rotations that move the reference point to the top of each embedding\n",
    "rot_eucl = rotation_to_top(mds_eucl[pt_id])\n",
    "rot_eff = rotation_to_top(mds_eff[pt_id])\n",
    "rot_diff = rotation_to_top(mds_diff[pt_id])\n",
    "\n",
    "# (title, embedding, rotation, distances to the reference point) for each MDS panel\n",
    "mds_panels = [\n",
    "    (\"Euclidean\", mds_eucl, rot_eucl, d_eucl[pt_id]),\n",
    "    (\"Eff. Resist.\", mds_eff, rot_eff, d_eff[pt_id]),\n",
    "    (\"Diffusion\", mds_diff, rot_diff, d_diff[pt_id]),\n",
    "]\n",
    "\n",
    "for k, (name, embd, rot, dist_to_ref) in enumerate(mds_panels, start=n_pd):\n",
    "    # MDS scatter plot, coloured by the distance to the reference point\n",
    "    scatter = plot_scatter(ax=ax[k], x=embd @ rot, y=dist_to_ref, alpha=1, s=1, cmap=\"viridis\", scalebar=False)\n",
    "    # reference point\n",
    "    ax[k].scatter(*(embd[pt_id] @ rot).T, c=\"k\")\n",
    "    # prettify MDS embeddings\n",
    "    ax[k].set_title(name)\n",
    "    ax[k].set_aspect(\"equal\", \"datalim\")\n",
    "    ax[k].set_title(\n",
    "        letters[k],\n",
    "        ha=\"right\",\n",
    "        loc=\"left\",\n",
    "        fontweight=\"bold\",\n",
    "    )\n",
    "\n",
    "# caption below the middle MDS panel; excluded from the layout computation\n",
    "txt = ax[n_pd + 1].text(x=-0.2, y=-0., s=r\"MDS ($d=50, \\sigma=0.25$)\", transform=ax[n_pd + 1].transAxes, va=\"top\")\n",
    "txt.set_in_layout(False)\n",
    "\n",
    "# colorbar of the last (diffusion) scatter plot, attached to all axes\n",
    "cbar = fig.colorbar(scatter, ax=ax, shrink=0.8, pad=0.02)\n",
    "cbar.ax.set_yticks([0, d_diff[pt_id].max()])\n",
    "cbar.ax.set_yticklabels([\"0\", \"$d_{max}$\"])\n",
    "\n",
    "fig.savefig(os.path.join(fig_path, \"fig_pd_by_dim_mds.pdf\"))"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "b073e2f4b2b35948"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Note: the title alignment in the saved PDF differs from what is displayed here."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "9c5c9a375cd7cd93"
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "conda-env-ph-py",
   "language": "python",
   "display_name": "Python [conda env:ph]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
