{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# imports\n",
    "from tueplots import bundles\n",
    "import wandb\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tueplots  import figsizes\n",
    "\n",
    "import sys\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "sys.path.insert(0, '.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "from analysis import sweep2df, BLUE, RED"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/Cellar/python@3.9/3.9.6/Frameworks/Python.framework/Versions/3.9/lib/python3.9/_collections_abc.py:940: MatplotlibDeprecationWarning: Support for setting the 'text.latex.preamble' or 'pgf.preamble' rcParam to a list of strings is deprecated since 3.3 and will be removed two minor releases later; set it to a single string instead.\n",
      "  self[key] = other[key]\n"
     ]
    }
   ],
   "source": [
    "plt.rcParams.update(bundles.neurips2022(usetex=True))\n",
    "plt.rcParams.update({\n",
    "    'text.latex.preamble': [r'\\usepackage{amsfonts}', # mathbb\n",
    "                            r'\\usepackage{amsmath}'] # boldsymbol\n",
    "})"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Data loading"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t Loading dsprites_sweep_7qg2q0aw.csv...\n"
     ]
    }
   ],
   "source": [
    "# Constants\n",
    "ENTITY = \"ima-vae\"\n",
    "PROJECT = \"dsprites\"\n",
    "SWEEP_ID = \"7qg2q0aw\"\n",
    "\n",
    "# W&B API\n",
    "api = wandb.Api(timeout=200)\n",
    "runs = api.runs(ENTITY + \"/\" + PROJECT)\n",
    "sweep = api.sweep(f\"{ENTITY}/{PROJECT}/{SWEEP_ID}\")\n",
    "filename = f\"dsprites_sweep_{SWEEP_ID}.csv\"\n",
    "\n",
    "\n",
    "runs_df = sweep2df(sweep.runs, filename, save=True, load=True)\n",
    "\n",
    "# runs_df = runs_df[runs_df.gamma_square == 1.]"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Data pre-processing"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "\n",
    "mcc_history, cima_history = [], []\n",
    "for run in sweep.runs:\n",
    "    if run.name in runs_df[runs_df.gamma_square == 1.].name.tolist():\n",
    "        mcc_history.append(run.history(keys=[f\"Metrics/val/mcc\"])[f\"Metrics/val/mcc\"].tolist())\n",
    "        cima_history.append(run.history(keys=[f\"Metrics/val/cima\"])[f\"Metrics/val/cima\"].tolist())"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [],
   "source": [
    "min_len = np.array([len(m) for m in mcc_history]).min()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "mcc = np.array([m[:min_len] for m in mcc_history])\n",
    "cima = np.array([m[:min_len] for m in cima_history])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Plot"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "outputs": [
    {
     "data": {
      "text/plain": "(5.5, 3.399186938124422)"
     },
     "execution_count": 105,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "figsizes.neurips2022(nrows=1, ncols=1, rel_width=1)['figure.figsize']"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/rb/d8k1n6bj4lg801y0yxz4jtbh0000gn/T/ipykernel_75393/2777838377.py:40: UserWarning: FixedFormatter should only be used together with FixedLocator\n",
      "  ax.set_xticklabels(range(0, min_len * val_epoch_factor, val_epoch_factor))\n"
     ]
    },
    {
     "data": {
      "text/plain": "<Figure size 396x122.371 with 3 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPIAAACCCAYAAAB8SpVlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAmcUlEQVR4nO2deXiU1b3HP7+ZyUz2hBBIIgESloABKyQoYhGQ1opSpe4bKvdewKXculRbe7WotBdrtZbWrdpFrVrXetGqVdkiUhENESoECUggJJJgCNknySzn/vFOwkzWSTLvTCZ5P88zz8y7ne85k/zmvO9ZvkeUUhgYGIQ3plBnwMDAoP8YgWxgMAgwAtnAYBBgBLKBwSDACGQDg0GAEcgGBoMAS6gz4I2IKJOp698WpRQiEsQcDR3NUOmGk6bb7UYpFfw/jB8MqEA2mUy4XK4uj9fW1hIfHx/EHA0dzVDphpOmiLh1yE5ACKtb66qqKkNzkOkOFU29CatAjo2NNTQHme5Q0dSbAXVr3R37519Kc3MLyf96K6i6zc3NQdXrrealL2jvry/u//kDvazhrKk3Aa+RRcQiIr8VkfGebZuILBWRq0VkYn/SttsbA5PJXrBv374hoRkq3aGiqTcBr5GVUk4RqQFaW/fmAFVAHvA/wB3e54vIcmC55zN5eXk+6aWmppKZmUlDfQMoOhwHmDVrFsXFxaSnp1NRUcHhw4d9jo8aNYr09HSKi4vJyspiy5YtHdKYPXs2RUVFZGZmUlpaSllZGQD19fXk5eUxevRoUlJSKC0tJTMzk61bt3ZIY86cOezevZusrCyKi4spLy/3OZ6RkUFSUhIVFRWkp6ezbds2n+Miwty5c7Hb7TidToqKijh69KjPOePGjSMuLo6qqipSUlKornYCkJe3AwCLxcLs2bMpKChg2rRp7N69m2PHjrVdX109jaioKCor66ivrycpKYmCggKfstpsNmbNmsX27dvJzc1l586dHD9+3CcfkyZNwmw209TURFxcHDt37vQ5HhUVxcyZM9vSKCgooLa21uec7OxsHA4HZWVl2Gw2du3a5XM8NjaWGTNmtKWRn59PfX29zzlTp06lubkZk8mE2WymsLDQ53h8fDw5OTltaWzbtq2tnK2ceuqp1NXVERkZicvlYu/evT5pDBs2jFNPPZUBjVIq4C/gPmCC5/P3gQuABOAX3V1nMplUV+w7+xL12fTvdHlcLzZt2jSgNS95XnsF4vyBXtZQawIupUO8BOKlx621GUgB0kRkCbAOGAWcD7wQaD0DAwN9bq1dwE2ezY88738IRNoRlohAJNMrhg8fPiQ0A6HrcGkvEbCYwGwCUw/DJ7rTVArcCpxu7d3lBocbnC7t3eE6se1sd8zpPnGdW51ISyk42JxJbaFnHyfOSYiE7/WrFSd0hE2rNUBMTEzQNadMmTIkNNvr1jbBRwdh5TotgM4YA40OsHteTQ6vbaf27uxiuITZE9gmk/bZ7HmvawaYSuz6EwHrcvsGrz6Mhe0d98ZEQNGdemnqS1gFcvuGjmCwY8cOcnJygqZ36QtQV9fA+zfFBU0TtNpp7UdfUm7NZuNX8FmpFkwmTxDuOQpRERBlgXgbpMRq29EREBlx4pjVDC51IiBdbnAqcLs77v9nEThaHFx4stUn0C2emtxsOhH0Zq9jFjNEeL1HtNu2ePa1XmPyvPC8791TyJTsbEyi3T20HreGVTT4ElZZj40Lfkd+MIO4lbi44ARxfTNsOQgbv9JeR+qyAcgeCTeeAfPHQW66Fhh6sHoBgFWfxLshe2R20DX1JqwCuaamJuiaW7ZsYfbs2UHV1MqZELD0XG6obYaaJu12tq4ZLn8RPj2sPU/GWmFOJlzo3seyBRNJC+LNQCi+31Bo6k1YBbIKgVGg0+nsdxq9HX3Vm3I2tECVHe5ddyJYa5ug2vNe29z6LOpLVAQsOx3OHg8z0rVb4ry8MtLigtvaE4jvNxw09SasAtnAl09KYFeF9sz58k6t1TU+EhJsMCZRe5aNj4TEyBOf//AJxFjhH0tCnXuDQGIEcpiyrQSufQVsZshOgbeu9++6V/+tb74MQoMRyGHIp4dh8StwUrxWC1vNoc6RgTcikgukAmOVUk947f9vtMFRa5RS5V1d3xfCahpjsFpzvZkxY0bQNbsr52eeIE6Lg1evCWwQh6Ksg0VTRFaKyBoRWQNsUUq9A0SKyFTP8SRgEuACKgKtH1Y1ckuLI+iaFRUVQZ+/2tLSAkR12P9ZKVzzitaH+9pi7T2QhKKsYaYpIpLvtf20UuppAKXUKq+Toj0fXYDdc7wKWCEiS4FTgIA+5IRVIEdYgp/dpKSkoGtaIjqWM78UFr/sCeJrAh/EEJqyhpmmUkr5U50/KyLnoU2y+Moz52A7MB1tJuDuvmagK8IqkLvz89KLuro6hg0bFlRNl9MFnBhXvr0MrnkZRniCOFWnJ4xQlHUwaiqlPm63/azn4xd6aeox+ylXRBaKyM2e7cki8p8icreI9OvbszfZA5PJXnDgwIGgazY1NbV9LvAEcXKMvkEMoSnrUNHUGz1q5KVKqZtE5HbPg/4CYDNwBDgXeNn75KFgLFBdPQ2AgwerezQWgLm43C6cTidvbSvhJ1vGEB/h4PZxO9i7vRmHYSzgc45hLOAh0BOc0RoAAG4BxgMjgRXAU8CkbidHD1Jjgd5M/r/keaW+8+hx9XmZUpMfVmrW40qV1gQm7Z7OD6dJ/qHQZAAbC+hRI7c96ANnAWvRHu73KKX2dnehgYbdZeaql2BYlNY6PSqAts/+DhM1CC/0MBb4uJPdmwKRtjXC/5ky++dfCsCEja/3S3PkyJH9ur43NDnhaD0cboxlVIL2TNxTEAcyMINZ1qGmqTdh1WodFd2xb1VvsrKydNcor4PntsMLn2sTIKIjPEEcuAlQfhGMsg5VTb0Jq5FdjQ3Bt8PdvTvgXX5tfP41rHgTZj4Oj36szULKHgnjYhtID3IQg75lHeqaehNWNXJMbPCtfjprrezttERvHC5490v4U77WtRRngyW58B8zIGOYJ+3I0KyEEIqW2aGiqTdhUyM7K6uoKT/a84kB5sMPPwxIOg4X/P5fcMYTcPObUNUIvzgH8lfA/edoQdxKdXV1QDR7S6DKamgGBhFJ9vfcsKiRXTW1NO89gMkagXK5EHPwpvuofpoZNLbAgSo42gD5ZZoTx4PnwfzxPTtMBpv+ltXQDAyesdrLgQuB+f5cExY1sjkhHtuEsZga7VQ+/kyos+M3VY1w+d+goh5GxMCmZfDSVfDdCQMviA0GBiLyM+B+tDHZ5/p7XVjUyADmkcm4y8opv/tB4i/4HrbMMaHOUrd8XQtXvwQl1ZCVDMOjIWtEqHOlDw6Hg9LSUp+hpf6SkJDAnj17dMhV/zXT0tJITEzUP0O+/BHIBiaiTYHza8pf2ASyiOBKT8VSWkHpDT9l3Pt/C/pK9/6yvxKuelnzzHrhSnjko56vCWdKS0uJi4sjIyOj13+Turq6oM8z90fTbrdTVlYW9EBWSlUCm0VkG9qIyN/4c11Y3Fq3Ej98OGm/upv6DR9R9cwrQdGcOXNmr87f8TX84HloccLfF8OZY3uvGRcffAMF6H1ZW2lqamL48OF9+mENxaID/mhGRkbicAR//ruIxAAopZqBJ/29LqwCubm5meHLryFmzky+vmMVjq8D6pbSKaWlpX6fu7kYLntR61Jaex1MTe2bZqjW7+1NWdvT17sjzUQhuPij2Z+7vfZLC3v2+bu88D9F5BER+S3tJhh1R1gFstVqRUwmRj/9EKq5mdIf3q17C2RKSopf571VCNe9AmMTtSDO7Md8+d4MRQ0k/pa1L+yff2nbsFlvIiKCv56XHprtrH5uA7yXFoYTywu/B9zQTVKL0YY0v+L57Bc9BrLn1+Fyr+1EETnZX4FA4nBqtzq2ieNIvf8Oat96n5rX3u53ugvv3M3COzsf7VNVVdXj9c9th5vXwvST4O/X9t+9I1S+y/6UNdD4U9YdO3Z0WCdab80uEBHJ93otbz2glFqllLrV83qok2ttaA1XbfY/naGUKkFbS/xitOWJ/cKfGnmSUupVL6FqoEu7ky6MBW723Fb0a7k/71bREbcuI2rGqZT96B6clfr9Ax48eLDLY0ppDVn/8z58ZwL87SrN1bK/9KX1NxB0V1a98Oc2d9u2bR3mGYPmGPP8889TWVkZcM0uUEqpGV6vpzs7yXtpYc/2EvxcXlhE7kNbvfRN4AF/M+ZPq/XfO9nX3b9re2OB48BJwLtAdfuT+2MsYF5+BXE3/ZzDt66k+Z4VPsYCsZ7RUXl5eT0aCzidw9ompftjLOB2Z1FTU8t/PNPAuvJRnDWinGuT92Izn8UXXwTOWKCoqKhDTTSunbFAfn6+z/GejAUAJk6ciM1mC5ixgNvtpq6ujsq7VtPyxZeewoDZbMblcmM2m2j8XLvj2TvnorY0TCYTbrcbk8lE5KnZJP7yJz4aJpOprWHq2LFjrF+/no0bN7Js2TI+/fRTamtr2bBhA1dccQUVFRWsX7+eiIgILBYL06dPp7m5mZ07d1JTU8O+fftwuVycddZZnHbaaSxfvpwnnngCi8XCO++8w5w5cxgxYgRKqbY2iqamJvLy8nptLKB8lxb2tvrxZ3lhB/AXtBi7HnjYH01/Armz3s9x3ZzfOuyq9RbChna/fyXwNXDQ+2TPr9rTAGazWc2bN6/TRGNiY6iursbn+DwoP3yUil+sIfPqi4idPJnY2FjGjx/P/sTHAJjhdX52trZ4V3uNh97Z7XN84sSJTJyotUfk5eX5nD958mT4DL5xJfJleSI3zIR75qdiEq1l65RTTmk7b/LkyQA85mlDyshIBDTXis7yAfDoC2A2mbFYLGRnZ7flqT2tnlNdfV+ti8+15qczkpOTfdJoX9bc3Fyg57HJe/bsIS4ujlqrFZfFd9Sd2azd9LW2HZnbHccNYhJMJlOXXUI2m43bb7+dQ4cO0dTUREFBATk5OWzdupUHHngAq9VKSkoK33zzDRaLhXPPPZe9e/eybt06Hn74YUSENWvWMGPGDN58801iYmKYOXMmw4YNw+FwMGHCBPLy8rj66qsBrS0GtJbr6dOnd1t2HbADMWjP137fOvgTyJUi8hTwNpoj3AXAzm7Ob28sUABMA75Ch8a1kT/7b6r//i6lN93FpC82Ytax66asFt75Er4o19YGvmc+3HSGbnJhx6jf3t/lsa7mh/vTp1tSUsLq1at5+eWXqamp4fLLL+fDDz+kuLiYo0eP8sMf/pCWlhZqa2uZMmUK+/fvJycnh82bN7N582ays7M5cuQI27dvZ/78+djtdkpKSnA6nXzwwQcsXLiQ9957j4qKCl0b/PykEq2hDMDvf+YeA1kp9WcROQpcjVbb/p9S6qVuzu/MWCAgHr6tv5TemGw2Rv/pYfZ/exFH7lpN+hN+P1b4hTkhnT99Cv/4UrOkBW1B7InD9Qni1xfDl1+WA4mBT7wHUlP72F/WDyx+WByvXLmyw75Dhw4xduxYRo8ejd1uJykpqcN5q1evbvv84IMPtn222+0sWrQIgIULFwJwzz339Cn/gUREHgJ+qZSqEZFYpZTfszt6/BZF5BHgE6XUVZ7tRBE5WSkV3HF1QGRk58YCMTNzSL5lKZVr/kjiFRcSO3dWv3TK67Sa9+098GnpBECbJ/zTufD9k+En7/Yr+R7JzMzUVyAEul05tdhstj6ld8stt/Q5L33VDAKvKaVaa+MzgPX+XhjwVms9sTd2bSyQuupOrOPGcnj5nbgbe2+b6xALz+TDxc/DjEdh5Tqoa4HrJ5Wz+QZYtxR+9G0YFwQ/9aKiIv1FBohuKFroQ9Ur4Acni8jFIpJKL2NMj1Zr3ehuaJ05Jpr0p37NgXOuoPx+v4anAmB3wOHoNI5Zh3HPBzB5BPx4DlwwGSYkg9s9ElOQh81MmTIluIIh1I2KCr59Uyg0/UEp9ZyIXIfmOPt+b67151+0t63WulFdU93t8bj53yZp6dV888jTuGrruz0XoPAonPcXOGYdRnJzFZuWwYZlcNtsLYgBNm/eHICc945QaIZKt71P9WDV9Bel1F+VUouU1yqO/qBHq3VIOenX91D77gZumXE/pphoOnucVUobjbVqgzaAY3z9IeKcDWSN6Nd4FQODfiMilwG5gAJ2KqX8Gm/dY42slPozWhBfjdYX/IFSak3fs6ov5oR4Rj/5ILjcuGvrOXj5DdS+uwHlGZZX1Qj/+Trc/QF8OwPWL4U4Z0NoM21gcIISpdRdSqmfAfv9vcifVuvTlVL/AP7hte/MLrqZBgTx3/8upo3/RrU4qP9wKzV/fwdLWgoHrr+V+4ZfxbEWC/d9F5aedmKggoG+9MewcIiRJiIPoA0I+QzI7+F8wL9b649FpMRrW4AkIASGrf4jZjMSZSb7cD7H397II+ubeS76AtJKD/L4nt9zRtIM3JMuxJwQwGUcusD45zXwF6XUWrTVWRARvz1l/Ank04C5QBnaYBCniEzuQx77TXx87387vrZb+WH9AvLT4dKJdm47tJGmdf+m9KbXKbvtXhIuOg+VvAzaDx30MGtW//qk+0IoNAOhu3IdFFZ0fmy3Z39rzdyKUrGIQHYKrDqn82ufeeYZ3n//fV5+WXtcvPbaaxk3bhzLly/n0UcfZd68eeTn5yMi1NXVtW0nJydz4403dkgvFGYG/iAiFwLfQxuaKUAGcFF317Tiz8iuz4HPPX1bS0Wkmc67pHSnqZfLqlZHxPG9P4PLDY8tgoumRAFLUT/+L+z5O6l69lWqX34T98XXgNmEs+o4liTflV+Li4vbxkwHi1BohkrXrdyYpXtX1MzMTFJTU/nkk0+IjIxkzJgxZGZmcvfdd3PzzTdz+umns2DBApYsWeKz3dVsrubm5oHaBfUR8K5SygkgIn6vbdMbzy43MAZtsnMGcG8vrg0I/o7Icbg8fcO2JKYNg8d/4OsbLSJEnzaN6NOmcdJvViI/3YNqbKLkulvIfOtZxKvjOD09PcCl6JlQaAZCt6saFbp+Rna5wB9347vuuotly5Zx/fXXt01ocTqdPnOL229nZGR0mlZnQ30HCH8C9nmekW8DXhSRNKVUj71E/hgLjBeRJ4FtQB3wLeDB7q/Sh5YW/zyUnvwEjtmSGNlUyf9d5xvE7TFFRmKyWpGoSOr+uZGjq3/vc7yiouO9on3nbuw79Vt2pDPNYBAKXX98sXbt2sWxY8eYNWsWOTk5HDhwgP379/PLX/6Sv/3tb7zxxhts3LiRG2+80We7K+uiUHhx+ckR4Dm0yUbjlFL78HOElz818l7gE+BXQBOwCPg2njnE7RGRXCAVGKuUesLjJrIUbTTYR/72i3VGc3PPQ+uKq+B3/4KElhpOsldgNfds1v+77feilMK2+BLK7/sNUadNI/7ceQAcPnyY8ePHd59AgAmFZqh0HQ4HkZHdDxRcsWIFcGLk2apVq9qOPfbYYz7nzp49OyCa/UFELMBDwGNKqa+89i8GpgLPdTFXoRpYAjQCMSKShBbUf+5J059AnqeU8pmJ75kN1RU+xgJKqV3Aj0XkGuD//NDrM0ppbh0WE6Q39s6YT0RIf/JX2HfupmTxCrLy38M6tvNbzd9tb32q6N+SrUOJQdJyLyLi3R30dKtLiIisROvNAa1h2MezS0Qi0Kbz2oADXaT/e+B0NF+vTOA6YJc/GfOnsWtLJ/v+2c0l7Y0FWrF6LD596I9DSCuzZs2iuLiYzxsy2FwcyfWZRWwva8KJfw4hsdXVJCQk8OXBYkY9/ygHZ/+AL867iro1P6e+paWDQ4jJraitremQlzlz5rB7d0eHkNjbfwlA8ht/6tEhZO7cudjt9rB0CGlfltjYWBoaGoiJiaGhoQG32+1zTmRkJEopWlpaEJEOkxlaHUL8SaNVs6c06uvrcblcPvmNiorC5XJhMpl6cghRSqlOb3WVUqu8tz2WPd7HHcAdIvJdtBUk3uokjaMisg8Y7bmtXtOZVqcopQL6As4EzkMz117i2XcG2iyqbq81mUyqKy55XqnvPHq8y+NVjUp967dKLXxGKadLqfPv2KXOv2NXl+d7s+/sS9S+sy9p265e+57aYRqlSm74idq0aVOP5/c2/Z7oTDMY9FW3sLBQNTY2Krfb3etra2tr+6TZH/zRdDqdqrCw0Gcf4FL+xYAZzZP6LM/2EiAdzT3zYiC2i+vuB15Fc8y5wR+t1lfAV5pQnYz4Ukp9Eoi0rbauWxsf2ATH7fDiVWDu52ylhEXnMvKnP+Tog4+TkpkOXdjp6MWoUaOCqtdf3bS0NMrKyvrUiORwOIJuieuvZlJS3+asqq49u57q4dJ/K6XuBRARvxZvayVsloyBrrufPj0ML+6AG2fC1AA5taSuupPGz3bQsGoN9nPnEzUteFP8wq37KTExsc9Lq9jt9qD36YZC00/iPX3H6Whumxv9vTCsDOqb7B1brVtc8NN/wqh4+PFZgdMSi4UxLz6Oio/j4GXLcR6vDlziPVBcXBw0rVDrDhVNP/kE+B1aH/KzvbkwrGrk6OjoDvv+8AkUVcKzl0G01513IFqWI0Ymk/naUxz87pUcvv5WMtb+xWewiF5kZWXprjFQdIeKZneIyDOcmLvQ2tL9FFo3r1+EVY1cU1Pjs93aZ7xwMpzT3Wo6/aDAaeek36yk9p31HP3VYz1f0E/2z7+UHacv0F2nM9q35huaQeM1tFr4NuBWz+tXvUkgrGpkb7z7jLsbGhgIht+8hIat2ylf+RDRp03TV8xgyKGU6sz/4lBv0girGtmbtYXa6oc/mwepOq9CKiKkP/VrIrOzOLR4Be6mnldLVG43jrIj1G/5FOc3x1CO0KznZDA0CMsa+bgd7lunLZp2bU5wNM0x0WS8/keKTj+f5sIiIqdNwXm8mpbiw7QUl9By8DAtB0pOfD5YivJeHjUiguOvvEXi5RcM2AXaDcKXsAzkQPYZ9wZb1jhG/+URDl22nMZ/5bM7earPcfOwBKyZY4icOpn4C87BmjEGa+ZojvxsNS3FJZRcfTPVL77BqMdXYx19UkDy1NUKDgZDi7AK5ISEBF36jLsLgvaD8BMvPp8jk8bjrqtnxG3LsWaMxpo5BlvmaMyJnRsfHP3140QmTiVx0QLKf/5r9k49m7TVdzH8xuuQTubwJSSExnzFnwkHhubAJKyekesbGnXpM+6OzkzbI1JHYJuYycjbbyDx4vOJnj61yyBuRUQYcesysv69gZgzZ1D2o5+zf+7FNO3e2+Fcex8M9gNBKAzqh4qm3oRVIFc7oyiqhP8917fPWE8CvYyKLXMMme++wJjnfkdz0QGKchdQft/DuL2ep/WcYtcdoViqZqho6k3YBLLdAWW14nef8YSNrwfkubGryen9QUQYtvgSJu/OI/HyC6j4xRqKchfQ8K/PANpm3wQbPcpqaAaHgAeyiOSKyEIRudlr3wLP/j5VNUpB8XFtyIvefcbtaV3wXA8sI4Yz5q+/J/PdF3A32tk/5yKa9xXTbO96jSs96W9Z98+/tK3xzd/za664qecTA4yef9NQoUdjl4+xABAPzAKOA31awbG+RQvmEZF2UuM6DtMMd+LPncekf2+g/N6HqVzzRywV3/DVdy/HOj4D24QMrOPGYpuQiXX8WMyx/XOA7G0rd7i2iodrvvuKHoHc3ljgHLQhaNPR5in7OHD6aywwJsaBw+EgL+/TDoKtxgLp6elUVFRw+PBhn+M9GQuA1pJZVFREZmYmpaWlbb/arZPt+2UsUF0NwMGDB7s3FrhwHqkfbqWh+BCuRjtVb7wDVb7DUmXEcKzjxuA6KYX4KZOoKS4BawQfvrEWlRiPxWrt1lggtrqaqKhoKisr/TIWSAAa6hs6lNXbWCAuLo6dO3e2lTMvL4+oqChmzpzZZk5QUFBAbW2tTxopLQ6UW1FWVobNZmPXLl8zjNjYWGbMmNGWRn5+fod1m6ZOnUpzczMmkwmz2UxhYaHPd15QUEBOTk5bGtu2bWsrZyunnnoqdXV1REZG4nK52LvXtwHSy1jAL0RkLpqt7Ril1LWefTbgWjQbn8+UZhwQMMQzoTlwCYqciTYAfDxQD+QB3wecwN+VUt90da3ZbFYul6vTY5e+ANXV1axfkRjQ/PZEXl4e89rNR9azVts//1Kqq6uZUaAtjeuqqaX5q0O0fHWI5q8O0rK/mOYDh2jZfxBHma+dkVitRIxOwzpmFBFj0rGOOYmI0aOwjk3X9o1O48DCa7vMS3/L2pfvxbusgaS7vPS1nCKigAKvXV1a/SilHhKRB5VSP/UcPweIQ4uH/1FK3dH7UnVNUIwFgIDMNgjFAtWjR4/udxq9vb2z2U40JZgT4onOOYXonFM6nOdutLP/7EtwNzWTfMO1tJSU4Sj5mpaSUurXb8bxdYX2TOJNhAVTVBSHl99JZHYWtpMnEpmdRUR6WkDK2lu8y9oTgbpd7kc5lfLf6mcq8BevXTbAQUcLrIAQVgNCQuFHnJISoFEnvcBq9c8xwxQdhSkmGlNMNMk3L+lw3N3SgqOsHEdJmSfIy6h84jncdjs1a9+j6s8vnUgrNoaIrHGUTJ1MZPZEIk/OwnbyBJRSARlSqpSi5cAh7AW7sH/+BY3bv6Dh43zEbKZ0xd3EnTOH2LPPxByv88B5uv6bqnZ+YH3F45Z5OlAtIo+jPVK+BPwHmmHAC91c3ifCKpC1bpngNnaVlpYGffWFQHU/maxWbJljsGWOadtXt+EjQKvVnN8co6mwiKbCfTTtKaJq+05c6zZz/K+vnUhEBLFZ2T//UiJGpRExKpWI9DSso9KISNe2LSkjfEaoKZeL5n3F2Au+wF7wBY2ff4H98924a7RnZImIIHLqJCzDh+G0N3H8r69x7MnnwGIhZlYucefMIe6cOUTlfqvTkW+9RSmFq7oGR+kRHKVHOJz/OcNdipbD2raj7AjNRQcwRQfGNUQp9QK+wfqs5/0PARHohLAK5FAMlAjF4IHIyODY0FhGDCd27ixi52prPo1obsZms+GqrqFpz36aCouoWPUI7uYWcLlp/DgfR1k5qr03l9lMxEkpuDyBumvYybgbtC40sdmIOvVkhl25iKjpU4nK/RaRU7Iw2Wzaj4NbMeGDl2jcup26Dz6kbt1myu99mPKVD2EelkDsd87SAvt7c30kldOJ89hxnN8cw1VZhbOyCuc3VbiOVWk/UHv2oVocfJk9F0fpkbb8tFIhgiVtJNb0NCInT8BVWxewQA4FYRXIWqtnYlA1t27d2qFhRG9qa2t6PkkHWstqTkwgZlYuMbNyOf6i1snQ+lyq3G6clVWemqy8rUZzlJVT89b7oBRJ/3klUdNPISpnKpGTJyDdGN3V1tZgslrbflDS/vcunN8co27DFurWbabugzxqXn8bAIm0gQi7kqfgOt71d2SKj0O1tHhq/snELThbu4tIP4mI9DTyDx9kzsWLfPLVm/7vgUhYBfJAIZz7JvubdzGZiBiZTMTIZGjXANcaDKPWrOrsUr+xjBjOsCsXMezKRZrPdGERdes2U/HA70FMJF72fSwjhmNJTsI8PMnzeZi2nZyEyWpty0vGqx2NK1Vefbc/LuGIEcgGAxoRIXLKJCKnTNJqfCD90f8NuE44/ziDEcgGIWTCxtc7XTnEoPcYgRzm6F2ThHtNNVQIq0AOxYT7OXPmBFVvwsbXO6xvFCyCXdahpKk3YTONEaChsSHomrt367cO8kDSDJVubzQDNTU1VN+vnoRVjRwdFfyZT0PJQH0wlbW7gB9oBvWBIKxq5PZLZgaDobSkyVAp6wBeMqbPhE0gv74YfpLVcQqj3rRORRzsmoHQ7cut71D6fvUk4LfWIpILpAJjlVJPePb9NzAKWKOUGnzfooFBiAmGQ8jXwCSgBqhof7K/xgJFRUUopTrtdwy2sUBmZiZbt27tkEZXxgKtZGRkdG8sIMLcuXOx2+04nU6Kioo4evSozznjxo0jLi6OqqoqUlJSyM/P9zlusVi6NRYAmDhxIjabzS9jgdzcXHbu3Mnx48d90ujMWMAbf4wFsrOzcTgcuhgLtBIfHx90Y4GQ0JtV0f1crf1pz/stwHiv/UuBb3V3rclk6nYV+U2bNnV7XA+GimaodMNJE3CpAMdLoF7BcAjZjmbzUw+8qbTV3Lu6VgHddaIKENgM98xQ0QyVbjhpmpRSA3K9n4AHsp6ISL7qwqHB0AxP3aGiqTdh02ptYGDQNUYgGxgMAsItkJ82NAed7lDR1JWwekY2MDDonHCrkQ0MDDrBCGQDg0HAgJ795L30BvAkcCHwkVLqHZ11o4B7gCa0kWkVQJXq3Hw/UJo/AGYC2cAbwMnAW3ppiogFeAht8YBZQBkQA+wCzgCGA39USrXopHkekAusV0q9qOcw3na6KXj+j4Dd6FTWYDPQa+SPlVJ3owVTLjASKOz+koDwLbQ/OECCUupttJFpuqGUWquU+hnwJpCBtvxIx1XQA6fnRBs2Gw3kKKU2AZcAy4APgCNoP6J6aJrQPJ5/DIwSkSS0YbwuOhnGG0Bdwff/SLeyBpsBHchKKUfr0htKqUfR/vBLgqC7TSm1FJiPtswHgO6uBiKSCnytlLof+DVwjd6aaP8DraPt6tF5aRMPypP+mcBvlVJVSqkVQDHQcW2cQAr7/h8Fo6xBYaDfWrcuvVEjIvXA52hLb+it+320VSUfAMaJyPnAq3rroq1c+ZaI/Ag4gLaKpS6IiBntriMeKBSROcAGtNvN84ERwJ900kwDbkOrfceKyGa0YbxVHv2A4q3r+dsWov0fudCprMHG6H4yMBgEDOhbawMDA/8wAjnMEJHTRaRYRH4gIheJSK+XdRCR4SKS3/OZBuHCgH5GNuiIUupTEalRSq0FEJF1fUjjmKfNwWCQYNTI4UmCp0a+G8gVkX+JyE0issVTY08SkRUiskxEVgCIyH+JyEIReUNETJ59l4nIOyIyMaSlMeg3Ro0cntQopdaKyNtADppzxZMiUgTcDEQAP1FKlYnIpyLyEZCllPqziOxRSrlFBKXUayIShzYoYl/oimPQX4waOYzxDHTYiebIAlAOVKP9XUd69n0NODnRP1vnCd5W3GhdbQZhjFEjhxkicjowUkQuQhupdA5gEpELgcnAb9CGWt4iIhuA15RSuz0NZO8D/wBeBFJEZDIw0ZOOQRhj9CMPAkQkTyk1L9T5MAgdxq11mCMio4F0EUkPdV4MQodRIxsYDAKMGtnAYBBgBLKBwSDACGQDg0GAEcgGBoMAI5ANDAYBRiAbGAwC/h/O/et/5HJFEwAAAABJRU5ErkJggg==\n"
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "LABELPAD = 1\n",
    "TICK_PADDING = 2\n",
    "\n",
    "fig = plt.figure(figsize=figsizes.neurips2022(nrows=1, ncols=2, rel_width=1)['figure.figsize'])\n",
    "\n",
    "\"\"\"MCC vs CIMA over different gamma\"\"\"\n",
    "ax = fig.add_subplot(121)\n",
    "ax.grid(True, which=\"both\", ls=\"-.\")\n",
    "\n",
    "# create double y-axis\n",
    "ax_cima = ax.twinx()\n",
    "ax_mcc = ax.twinx()\n",
    "\n",
    "# Link the respective y-axes for grid and plot\n",
    "ax.get_shared_y_axes().join(ax, ax_mcc)\n",
    "\n",
    "# Remove ticks and labels and set which side to label\n",
    "ticksoff = dict(labelleft=False, labelright=False, left=False, right=False)\n",
    "ax.tick_params(axis=\"y\", **ticksoff)\n",
    "ax_mcc.tick_params(axis=\"y\", labelleft=True, labelright=False, left=True, right=False)\n",
    "ax_cima.tick_params(axis=\"y\", labelleft=False, labelright=True, left=False, right=True)\n",
    "\n",
    "# MCC\n",
    "ax_mcc.errorbar(range(min_len), mcc.mean(0), yerr=mcc.std(0), label='mcc', c=BLUE)\n",
    "\n",
    "# CIMA\n",
    "ax_cima.errorbar(range(min_len), np.log10(cima).mean(0), yerr=np.log10(cima).std(0), label='cima', c=RED)\n",
    "\n",
    "\n",
    "# set z-order to make CIMA the top plot\n",
    "# https://stackoverflow.com/a/30506077/16912032\n",
    "# ax.set_zorder(ax.get_zorder()+1)\n",
    "# ax.set_frame_on(False)\n",
    "\n",
    "ax_cima.set_ylabel(\"$\\log_{10} c_{\\mathrm{IMA}}$\", labelpad=LABELPAD)\n",
    "ax.set_ylabel(\"$\\mathrm{MCC}$\", labelpad=LABELPAD+17)\n",
    "\n",
    "ax.set_xlabel(\"Epoch\", labelpad=LABELPAD)\n",
    "val_epoch_factor = 25\n",
    "ax.set_xticklabels(range(0, min_len * val_epoch_factor, val_epoch_factor))\n",
    "ax.grid(True, which=\"both\", ls=\"-.\")\n",
    "\n",
    "handle1, label1 = ax_mcc.get_legend_handles_labels()\n",
    "handle2, label2 = ax_cima.get_legend_handles_labels()\n",
    "\n",
    "plt.legend([handle2[0], handle1[0]],[\"$\\log_{10} c_{\\mathrm{IMA}}$\", \"$\\mathrm{MCC}$\"], loc='center right')\n",
    "\n",
    "plt.savefig(\"dsprites_mcc_cima.svg\")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "outputs": [
    {
     "data": {
      "text/plain": "(5.5, 3.399186938124422)"
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "figsizes.neurips2022(nrows=1, ncols=1)['figure.figsize']"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}