{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/julia/miniconda3/envs/bs2/lib/python3.8/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.24.3\n",
      "  warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy.stats import skew, kurtosis\n",
    "import os\n",
    "from sklearn import metrics\n",
    "import wandb\n",
    "import joblib\n",
    "from braivest.model.emgVAE import emgVAE\n",
    "import plotly.express as px\n",
    "import tensorflow as tf\n",
    "from optuna_utils import *\n",
    "from braivest.analysis.plotting_utils import plot_encodings, make_figure_nice\n",
    "\n",
    "from braivest.analysis.wandb_utils import load_wandb_model\n",
    "from scipy.spatial import procrustes\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def model_encode(model_path, trial_number, repeat_id, input_dim, study_config, test):\n",
    "    n_layers = int(study_config['n_layers'])\n",
    "    layer_dims = int(study_config['layer_dims'])\n",
    "    \n",
    "    layers = [layer_dims for layer in range(n_layers)]\n",
    "    model = emgVAE(input_dim = input_dim, latent_dim = 2, \n",
    "                   hidden_states = layers, kl = study_config['kl'], \n",
    "                   emg = False)\n",
    "\n",
    "    model.build((None, input_dim))\n",
    "    model.load_weights(os.path.join(model_path, f'model_weights_{repeat_id}_{trial_number}.h5'))\n",
    "    encoded = model.encode(test, numpy=True)\n",
    "    tf.keras.backend.clear_session()\n",
    "    del model    \n",
    "    return encoded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "artifact_dir = 'artifacts/synthetic_hmm:v0'\n",
    "train = np.load(os.path.join(artifact_dir, 'train.npy'))\n",
    "test = np.load(os.path.join(artifact_dir, 'test.npy'))\n",
    "test_hypno = np.load(os.path.join(artifact_dir, 'test_hypno.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "runs_path = 'neighbor_vae_expts/runs_and_models/vae_hmm/runs15-18/run15_VAE_HMM/models/'\n",
    "study = joblib.load(os.path.join(runs_path, f\"study_r_0.pkl\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "for index in [0, 1, 115]:\n",
    "    best_encodings = model_encode(runs_path, index, 0, 31, study.trials[index].params, test)\n",
    "    fig = plot_encodings(best_encodings, color=test_hypno)\n",
    "    fig = make_figure_nice(fig)\n",
    "    fig.update_layout(showlegend=False)\n",
    "    fig.update_layout(xaxis_title='Dimension 1', yaxis_title='Dimension 2')\n",
    "    fig.update_coloraxes(showscale=False)\n",
    "    fig.write_image('revised_images/figure1/model{}.svg'.format(index), format='svg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = ['01', '02', '03', '04', '05', '06', '07', '08', '09', '10', '50', '100', '150', '200', '250', '300']\n",
    "for epoch in epochs:\n",
    "    model = load_wandb_model(\"neighbor_vae_experiments/3kdqnvy3\", 31, epoch=epoch)\n",
    "    test_encodings = model.encode(test, numpy=True)\n",
    "    fig = plot_encodings(test_encodings, color=test_hypno)\n",
    "    fig = make_figure_nice(fig)\n",
    "    fig.update_coloraxes(showscale=False)\n",
    "    fig.update_layout(xaxis_title='Dimension 1', yaxis_title='Dimension 2')\n",
    "    fig.write_image('revised_images/figure1/epoch_{}.svg'.format(epoch), format='svg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics_hmm_vanilla = pd.read_csv('metrics_hmm_vanilla')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "comb_array = np.array(np.meshgrid(np.arange(216), np.arange(216))).T.reshape(-1, 2)\n",
    "comb_indices = np.random.choice(np.arange(comb_array.shape[0]), size=500, replace=False)\n",
    "delta_val_loss = []\n",
    "procrustes_vals = []\n",
    "avg_val_loss = []\n",
    "for index in comb_indices:\n",
    "    model_index1 = comb_array[index, 0]\n",
    "    model_index2 = comb_array[index, 1]\n",
    "    encodings1 = model_encode(runs_path, model_index1, 0, 31, study.trials[model_index1].params, test)\n",
    "    encodings2 = model_encode(runs_path, model_index2, 0, 31, study.trials[model_index2].params, test)\n",
    "    val_loss1 = metrics_hmm_vanilla.loc[model_index1, 'val_loss']\n",
    "    val_loss2 = metrics_hmm_vanilla.loc[model_index2, 'val_loss']\n",
    "    avg_val_loss.append((val_loss1+val_loss2)/2)\n",
    "    delta_val_loss.append(np.abs(val_loss1- val_loss2))\n",
    "    procrustes_vals.append(procrustes(encodings1, encodings2)[-1])\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = px.scatter(x=delta_val_loss, y=procrustes_vals, range_x=(0, 6), color=avg_val_loss, color_continuous_scale='Viridis')\n",
    "fig.update_traces(marker=dict(size=12, opacity=0.8))\n",
    "fig = make_figure_nice(fig)\n",
    "fig.update_layout(xaxis_title='delta Val loss', yaxis_title='Procrustes distance', xaxis=dict(tickvals=(0, 3, 6)), yaxis=dict(tickvals=(0, 0.5, 1)))\n",
    "fig.write_image('revised_images/figure1/deltavallossvsprocrustes.svg')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bs2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
