{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "\n",
    "# Run this cell to start the Demo GUI\n",
    "# When switching from one file to another, the loading time can become a few seconds.\n",
    "root_path = \"./Examples/\"\n",
    "\n",
    "%matplotlib widget\n",
    "\n",
    "if True:\n",
    "    import sys\n",
    "    import os\n",
    "    from pathlib import Path\n",
    "    # Make the working directory (assumed to be the notebook's own directory) and\n",
    "    # FactorRotations/utils importable.\n",
    "    notebook_dir = Path.cwd().resolve()\n",
    "    sys.path.append(str(notebook_dir))\n",
    "    sys.path.append(os.path.join(str(notebook_dir), \"FactorRotations\", \"utils\"))\n",
    "\n",
    "    from scipy.stats import norm\n",
    "    from ipywidgets import widgets\n",
    "\n",
    "    import matplotlib.pyplot as plt\n",
    "    import warnings\n",
    "    # Raw string: this is a regex, and \"\\.\" is not a valid escape in a plain string literal.\n",
    "    warnings.filterwarnings(\"ignore\", module=r\"matplotlib\\..*\")\n",
    "    import numpy as np\n",
    "    # Ignore all numpy floating-point error conditions (divide, overflow, under, invalid).\n",
    "    np.seterr(all=\"ignore\")\n",
    "    from demo_utils.datahandler import DataHandler\n",
    "\n",
    "    # Example files are the .pkl files found in root_path, in sorted order.\n",
    "    pkl_samples = sorted(p for p in os.listdir(root_path) if p.endswith(\".pkl\"))\n",
    "    if not pkl_samples:\n",
    "        raise FileNotFoundError(f\"No .pkl example files found in {root_path!r}\")\n",
    "\n",
    "    # Load every example once up front and discard the result.\n",
    "    # NOTE(review): presumably this warms a cache inside DataHandler so that switching\n",
    "    # examples later is faster -- confirm against demo_utils.datahandler.\n",
    "    for pkl_sample in pkl_samples:\n",
    "        _ = DataHandler(os.path.join(root_path, pkl_sample))\n",
    "\n",
    "    data = DataHandler(os.path.join(root_path, pkl_samples[0]))\n",
    "    \n",
    "    def _hide_axes_decorations(axes):\n",
    "        \"\"\"Hide axis lines, ticks and tick labels and use an equal aspect ratio on each Axes.\"\"\"\n",
    "        for ax in axes:\n",
    "            ax.axis('off')\n",
    "            ax.set_xticklabels([])\n",
    "            ax.set_yticklabels([])\n",
    "            ax.set_aspect('equal')\n",
    "\n",
    "    # Figures are built with interactive mode off; their canvases are embedded in the\n",
    "    # widget layout at the end of the cell. clear=True because plt.figure(num) returns the\n",
    "    # already-existing figure when the cell is re-run, and add_subplot would otherwise\n",
    "    # stack new Axes on top of the old ones.\n",
    "    plt.ioff()\n",
    "    sample_fig = plt.figure(1, figsize=(3,6), clear=True)\n",
    "    sample_fig_ax = [sample_fig.add_subplot(2,1,1), sample_fig.add_subplot(2,1,2)]\n",
    "    sample_fig.canvas.resizable = False\n",
    "    sample_fig_ax[0].set_title(\"Mean Prediction\")\n",
    "    sample_fig_ax[1].set_title(\"Factor Model Prediction\")\n",
    "    _hide_axes_decorations(sample_fig_ax)\n",
    "\n",
    "    plt.ion()\n",
    "    im_frac = sample_fig_ax[0].imshow(data.cmap(data.get_mean_pred()))\n",
    "    im_rounded = sample_fig_ax[1].imshow(data.cmap(data.get_prediction()))\n",
    "    plt.ioff()\n",
    "\n",
    "    # 20 single-Axes figures: two per factor panel, for 10 factor panels.\n",
    "    factor_figs = [plt.figure(num=(i+1)*10, figsize=(1.0,1.0), clear=True) for i in range(20)]\n",
    "    for f in factor_figs:\n",
    "        f.add_subplot(1,1,1)\n",
    "        f.canvas.resizable=True\n",
    "\n",
    "    axs_fac = [fig.axes[0] for fig in factor_figs]\n",
    "    _hide_axes_decorations(axs_fac)\n",
    "\n",
    "    # NOTE(review): this only acts on the current figure (the last factor figure), which\n",
    "    # has a single subplot, so wspace/hspace appear to have no visible effect.\n",
    "    plt.subplots_adjust(wspace=0, hspace=0.3)\n",
    "    plt.ion()\n",
    "\n",
    "    # Axes 2*i shows factors_neg[i], axes 2*i+1 shows factors_pos[i].\n",
    "    factors_pos, factors_neg = data.plot_factors()\n",
    "    for i in range(10):\n",
    "        axs_fac[2*i].imshow(factors_neg[i])\n",
    "        axs_fac[2*i+1].imshow(factors_pos[i])\n",
    "    plt.ioff()\n",
    "    label_fig = plt.figure(3, (16,2), clear=True)\n",
    "    axs_label = [label_fig.add_subplot(1,6, i+1) for i in range(6)]\n",
    "\n",
    "    label_fig.suptitle('Input and Label(s)', fontsize=12, x=0.18)\n",
    "\n",
    "    density_text = widgets.Text(\"Pseudo-Density: 1\")\n",
    "    # Display-only text box; its value is written by update_density().\n",
    "    density_text.disabled = True\n",
    "\n",
    "    plt.ion()\n",
    "    data.plot_sample_and_labels(axs_label)\n",
    "    _hide_axes_decorations(axs_label)\n",
    "\n",
    "    def show_sample():\n",
    "        \"\"\"Redraw the mean and factor-model prediction panels from the current `data` state.\"\"\"\n",
    "        for ax, get_image in ((sample_fig_ax[0], data.get_mean_pred),\n",
    "                              (sample_fig_ax[1], data.get_prediction)):\n",
    "            # Remove the previous image first; otherwise every slider move stacks another\n",
    "            # AxesImage on the axes, all of which keep being redrawn.\n",
    "            for old_image in list(ax.images):\n",
    "                old_image.remove()\n",
    "            ax.imshow(data.cmap(get_image()))\n",
    "        sample_fig.canvas.draw_idle()\n",
    "\n",
    "    def update_density():\n",
    "        \"\"\"Show prod_k pdf(w_k) / pdf(0) over the factor sliders (pdf = standard normal).\"\"\"\n",
    "        gaussian = norm()\n",
    "        peak = gaussian.pdf(0)\n",
    "        ratios = [gaussian.pdf(factor_slider.value) / peak for factor_slider in slider_list[1:]]\n",
    "        density = 1\n",
    "        for ratio in ratios:\n",
    "            density *= ratio\n",
    "        density_text.value = f\"Pseudo-Density: {np.round(density, 5)}\"\n",
    "\n",
    "    def _factor_heading_html(idx, weight):\n",
    "        \"\"\"HTML heading for factor idx+1 showing `weight` with two decimals.\"\"\"\n",
    "        prefix = f\"<font size='4'><center><b>Factor {idx+1}:</b>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;\"\n",
    "        return prefix + f\"{weight:.2f}\" + \"</center>\"\n",
    "\n",
    "    # One heading per factor panel; all start at weight 0.\n",
    "    slider_headings = [widgets.HTML(_factor_heading_html(i, 0)) for i in range(10)]\n",
    "\n",
    "    def update_slider_headings():\n",
    "        \"\"\"Refresh every factor heading from data.factor_weights.\"\"\"\n",
    "        for idx, heading in enumerate(slider_headings):\n",
    "            heading.value = _factor_heading_html(idx, data.factor_weights[idx])\n",
    "\n",
    "    # slider_list[0] drives data.scale_covariance; slider_list[1:] drive factors 1..rank.\n",
    "    # The covariance slider starts at 1, the same neutral value reset() restores.\n",
    "    slider_list = [widgets.FloatSlider(value=1, min=-2, max=2, readout=False)]\n",
    "\n",
    "    # Maps each slider's widget model_id to the label update() dispatches on.\n",
    "    src_dict = {}\n",
    "    src_dict[slider_list[-1].model_id] = \"Covariance\"\n",
    "\n",
    "    for i in range(data.factor_model.rank):\n",
    "        slider_list.append(widgets.FloatSlider(value=0, min=-2, max=2, readout=False))\n",
    "        src_dict[slider_list[-1].model_id] = f\"Factor {i+1}\"\n",
    "\n",
    "    def update(change):\n",
    "        \"\"\"Slider observer: apply the moved slider's new value to `data` and refresh the UI.\n",
    "\n",
    "        `change` is the traitlets change dict; change[\"owner\"] is the slider that moved.\n",
    "        \"\"\"\n",
    "        src = src_dict[change[\"owner\"].model_id]\n",
    "\n",
    "        if src == \"Covariance\":\n",
    "            data.scale_covariance(change[\"new\"])\n",
    "        else:\n",
    "            # src is \"Factor <n>\" with a 1-based n.\n",
    "            # NOTE(review): update_slider_headings reads data.factor_weights 0-based for\n",
    "            # \"Factor n+1\"; confirm scale_factor expects this 1-based id.\n",
    "            factor_id = int(src.split(\" \")[-1])\n",
    "            data.scale_factor(factor_id, change[\"new\"])\n",
    "\n",
    "        update_density()\n",
    "        show_sample()\n",
    "        update_slider_headings()\n",
    "\n",
    "\n",
    "    # Attach `update` to every slider: index 0 (covariance) and 1..rank (factors).\n",
    "    for slider in slider_list:\n",
    "        slider.observe(update, names='value')\n",
    "\n",
    "\n",
    "    reset_btn = widgets.Button(description=\"Reset\")\n",
    "\n",
    "    img_dd = widgets.Dropdown(options=pkl_samples, layout={'width': 'max-content'})\n",
    "\n",
    "    rotations_dd = widgets.Dropdown(options=data.available_rotations)\n",
    "    def update_rotation(change):\n",
    "        \"\"\"Rotation-dropdown observer: rotate the factors, reset sliders and redraw all figures.\"\"\"\n",
    "        if change['type'] == 'change' and change['name'] == 'value':\n",
    "            data.rotate_factors(change[\"new\"])\n",
    "            reset(None)\n",
    "            show_sample()\n",
    "            factors_pos, factors_neg = data.plot_factors()\n",
    "            for i in range(10):\n",
    "                for ax, image in ((axs_fac[2*i], factors_neg[i]), (axs_fac[2*i+1], factors_pos[i])):\n",
    "                    # Remove the previous image so repeated rotations don't pile up AxesImage artists.\n",
    "                    for old_image in list(ax.images):\n",
    "                        old_image.remove()\n",
    "                    ax.imshow(image)\n",
    "\n",
    "\n",
    "    rotations_dd.observe(update_rotation, names=\"value\")\n",
    "\n",
    "    def update_img(change):\n",
    "        \"\"\"Example-dropdown observer: load the selected .pkl file and redraw every figure.\"\"\"\n",
    "        reset(None)\n",
    "        data.update_example(os.path.join(root_path, change[\"new\"]))\n",
    "        # Fires the rotation observer only if the dropdown value actually changes.\n",
    "        rotations_dd.value = \"Unrotated\"\n",
    "        factors_pos, factors_neg = data.plot_factors()\n",
    "        for i in range(10):\n",
    "            for ax, image in ((axs_fac[2*i], factors_neg[i]), (axs_fac[2*i+1], factors_pos[i])):\n",
    "                # Remove the previous image so AxesImage artists don't accumulate.\n",
    "                for old_image in list(ax.images):\n",
    "                    old_image.remove()\n",
    "                ax.imshow(image)\n",
    "\n",
    "        # Drop the previous example's input/label images so none are left behind.\n",
    "        # NOTE(review): assumes plot_sample_and_labels draws fresh images via imshow.\n",
    "        for ax in axs_label:\n",
    "            for old_image in list(ax.images):\n",
    "                old_image.remove()\n",
    "        data.plot_sample_and_labels(axs_label)\n",
    "        show_sample()\n",
    "    img_dd.observe(update_img, names=\"value\")\n",
    "\n",
    "\n",
    "    def reset(b):\n",
    "        \"\"\"Call data.reset(), put the sliders back to their defaults and refresh the display.\n",
    "\n",
    "        `b` is the clicked Button when used as an on_click callback; it is unused.\n",
    "        The covariance slider (index 0) goes back to 1, the factor sliders to 0.\n",
    "        \"\"\"\n",
    "        data.reset()\n",
    "        for idx, slider in enumerate(slider_list):\n",
    "            # Detach only our own `update` handler while the value is set programmatically,\n",
    "            # so data.scale_* is not called again; unobserve_all() would also drop every\n",
    "            # other observer registered on the widget.\n",
    "            try:\n",
    "                slider.unobserve(update, names='value')\n",
    "            except ValueError:\n",
    "                pass  # `update` was not attached to this slider\n",
    "            slider.value = 1 if idx == 0 else 0\n",
    "            slider.observe(update, names='value')\n",
    "\n",
    "        update_slider_headings()\n",
    "        show_sample()\n",
    "        update_density()\n",
    "\n",
    "    reset_btn.on_click(reset)\n",
    "\n",
    "    def resample(b):\n",
    "        \"\"\"Call data.resample(), move the factor sliders to data.factor_weights and refresh.\n",
    "\n",
    "        `b` is the clicked Button when used as an on_click callback; it is unused.\n",
    "        \"\"\"\n",
    "        reset(None)\n",
    "        data.resample()\n",
    "        for i, slider in enumerate(slider_list[1:]):\n",
    "            # Detach only our `update` handler while setting the value programmatically so\n",
    "            # data.scale_factor is not re-triggered; unobserve_all() would also drop every\n",
    "            # other observer registered on the widget.\n",
    "            try:\n",
    "                slider.unobserve(update, names='value')\n",
    "            except ValueError:\n",
    "                pass  # `update` was not attached to this slider\n",
    "            slider.value = data.factor_weights[i]\n",
    "            slider.observe(update, names='value')\n",
    "\n",
    "        update_slider_headings()\n",
    "        show_sample()\n",
    "        update_density()\n",
    "\n",
    "    resample_btn = widgets.Button(description=\"Resample\")\n",
    "    resample_btn.on_click(resample)\n",
    "\n",
    "    # Hide the ipympl toolbar, header and footer on every embedded figure canvas.\n",
    "    for chrome_fig in [*factor_figs, sample_fig, label_fig]:\n",
    "        for attr in (\"toolbar_visible\", \"header_visible\", \"footer_visible\"):\n",
    "            setattr(chrome_fig.canvas, attr, False)\n",
    "\n",
    "    def _factor_panel(idx):\n",
    "        \"\"\"Bordered row for factor idx+1: factors_neg canvas | heading + slider | factors_pos canvas.\"\"\"\n",
    "        return widgets.HBox(\n",
    "            [factor_figs[2*idx].canvas,\n",
    "             widgets.VBox([slider_headings[idx], slider_list[idx+1]]),\n",
    "             factor_figs[2*idx+1].canvas],\n",
    "            layout=widgets.Layout(border='solid'),\n",
    "        )\n",
    "\n",
    "    main_v1 = widgets.HBox([rotations_dd, reset_btn, resample_btn, density_text])\n",
    "    main_v3 = widgets.VBox([sample_fig.canvas])\n",
    "    # Five rows, each pairing factor i+1 (left) with factor i+6 (right).\n",
    "    factor_rows = [widgets.HBox([_factor_panel(i), _factor_panel(i + 5)]) for i in range(5)]\n",
    "    main_v2 = widgets.VBox([main_v1, *factor_rows])\n",
    "\n",
    "    main = widgets.HBox([main_v2, main_v3])\n",
    "widgets.VBox([img_dd, label_fig.canvas, main])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "5f7c65b4124957e14e4ccd35b435ba38486c2ceeca87c300fe3c944cabf810eb"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
