{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "from braivest.model.emgVAE import emgVAE\n",
    "import plotly.express as px\n",
    "import tensorflow as tf\n",
    "from optuna_utils import *\n",
    "from braivest.analysis.plotting_utils import plot_encodings, make_figure_nice\n",
    "import matplotlib.pyplot as plt\n",
    "import joblib\n",
    "from umap.umap_ import UMAP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "artifact_dir = 'artifacts/synthetic_hmm:v0'\n",
    "test = np.load(os.path.join(artifact_dir, 'test.npy'))\n",
    "test_hypno = np.load(os.path.join(artifact_dir, 'test_hypno.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "runs_path = 'runs_and_models/vae_hmm/runs15-18/run17_VAE_latent_3/models/'\n",
    "study = joblib.load(os.path.join(runs_path, f\"study_r_0.pkl\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "study.best_trial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encodings_3d = model_encode(runs_path, study.best_trial.number, 0, 31, study.best_trial.params, test, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = px.scatter_3d(encodings_3d, x=0, y=1, z=2, color=test_hypno)\n",
    "fig.update_traces(marker=dict(size=2, opacity=0.5))\n",
    "fig = make_figure_nice(fig)\n",
    "fig.update_scenes(xaxis_visible=False, yaxis_visible=False,zaxis_visible=False )\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reducer = UMAP()\n",
    "embedding = reducer.fit_transform(encodings_3d)\n",
    "fig = plot_encodings(embedding, color=test_hypno)\n",
    "fig = make_figure_nice(fig)\n",
    "fig.update_layout(xaxis_title='UMAP Dimension 1', yaxis_title='UMAP Dimension 2')\n",
    "fig.show()\n",
    "fig.write_image('revised_images/figs1_3d_umap.svg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "runs_path = 'neighbor_vae_expts/runs_and_models/vae_hmm/runs15-18/run18_VAE_latent_4/models/'\n",
    "study = joblib.load(os.path.join(runs_path, f\"study_r_0.pkl\"))\n",
    "encodings_4d = model_encode(runs_path, study.best_trial.number, 0, 31, study.best_trial.params, test, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encodings_4d.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reducer = UMAP()\n",
    "embedding = reducer.fit_transform(encodings_4d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plot_encodings(embedding, color=test_hypno)\n",
    "fig = make_figure_nice(fig)\n",
    "fig.update_layout(xaxis_title='UMAP Dimension 1', yaxis_title='UMAP Dimension 2')\n",
    "fig.show()\n",
    "fig.write_image('revised_images/figs1_4d_umap.svg')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bs2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
