{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%env CUDA_VISIBLE_DEVICES=1\n",
    "import matplotlib.pyplot as plt\n",
    "from utils.utils import get_path\n",
    "from utils.toydata_utils import get_toy_data\n",
    "from utils.fig_utils import dataset_to_print\n",
    "from vis_utils.plot import plot_scatter\n",
    "import os\n",
    "from sklearn.decomposition import PCA"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "4ba068d8572ac10a"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# apply the style sheet named by style_file to every figure created below\n",
    "style_file = \"utils.style\"\n",
    "plt.style.use(style_file)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "d02cd99cd5de604e"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# figures are written to <root_path>/figures\n",
    "root_path = get_path(\"data\")\n",
    "fig_path = os.path.join(root_path, \"figures\")\n",
    "\n",
    "# savefig does not create missing directories, so make sure the output folder exists\n",
    "os.makedirs(fig_path, exist_ok=True)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "7a7907fbcbc1a396"
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Figure: Noiseless toy datasets"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "54df52191a83c6ce"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# hyperparameters\n",
    "n = 1000\n",
    "\n",
    "# dimension requested for each toy dataset; the key order is also the panel order\n",
    "embd_dims = {\n",
    "    \"toy_circle\": 2,\n",
    "    \"eyeglasses\": 2,\n",
    "    \"inter_circles\": 3,\n",
    "    \"toy_sphere\": 3,\n",
    "    \"torus\": 3,\n",
    "}\n",
    "datasets = list(embd_dims)\n",
    "\n",
    "# generate every toy dataset with the same fixed seed\n",
    "raw_data = {\n",
    "    name: get_toy_data(n=n, dataset=name, seed=0, d=embd_dims[name])\n",
    "    for name in datasets\n",
    "}\n",
    "\n",
    "# express each dataset in its first embd_dims[name] principal components\n",
    "data = {\n",
    "    name: PCA(n_components=embd_dims[name]).fit_transform(points)\n",
    "    for name, points in raw_data.items()\n",
    "}"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "466c7b8d805033bc"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# plot figure: one panel per toy dataset, labelled a-e\n",
    "mosaic = \"abcde\"\n",
    "\n",
    "# the first two panels are 2D, the remaining ones get 3D axes;\n",
    "# layout=\"constrained\" is requested explicitly because the padding tweak at the\n",
    "# end of this cell needs the figure to have a layout engine\n",
    "fig, ax = plt.subplot_mosaic(\n",
    "    mosaic=mosaic,\n",
    "    figsize=(5.5, 1.5),\n",
    "    width_ratios=[0.4, 1, 0.8, 0.8, 0.8],\n",
    "    layout=\"constrained\",\n",
    "    per_subplot_kw={\n",
    "        letter: {} if letter in [\"a\", \"b\"] else {\"projection\": \"3d\"}\n",
    "        for letter in mosaic\n",
    "    },\n",
    ")\n",
    "\n",
    "for letter, dataset in zip(mosaic, datasets):\n",
    "    panel = ax[letter]\n",
    "    is_3d = embd_dims[dataset] == 3\n",
    "\n",
    "    # plot the data\n",
    "    plot_scatter(x=data[dataset], ax=panel, y=\"k\", s=1, alpha=0.5, scalebar=False, clip_on=False)\n",
    "    panel.axis(\"on\")\n",
    "\n",
    "    # set view for 3D datasets or spine position for 2D datasets\n",
    "    if is_3d:\n",
    "        panel.view_init(45, 30)\n",
    "        panel.tick_params(pad=-5)\n",
    "    else:\n",
    "        panel.spines[\"left\"].set_position((\"outward\", 3))\n",
    "        panel.spines[\"bottom\"].set_position((\"outward\", 3))\n",
    "\n",
    "        panel.set_ylim(-1.0, 1.0)\n",
    "        panel.set_yticks([-1, 0, 1])\n",
    "\n",
    "    panel.set_aspect(\"equal\", \"box\")\n",
    "\n",
    "    # position dataset names and panel letters in axes coordinates;\n",
    "    # 3D panels use text2D and a lower height, 2D panels use text\n",
    "    add_label = panel.text2D if is_3d else panel.text\n",
    "    label_y = 1.05 if is_3d else 1.7\n",
    "    add_label(0.2, label_y, dataset_to_print[dataset], fontsize=7, transform=panel.transAxes)\n",
    "    add_label(-0.05, label_y, letter, fontsize=7, transform=panel.transAxes, fontweight=\"bold\")\n",
    "\n",
    "# reduce the padding between subplots\n",
    "fig.get_layout_engine().set(w_pad=4 / 72, h_pad=4 / 72, hspace=0, wspace=0)\n",
    "fig.savefig(os.path.join(fig_path, \"toy_data.pdf\"))"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "b2e8779cde684170"
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "conda-env-ph-py",
   "language": "python",
   "display_name": "Python [conda env:ph]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
