{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea98285b",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import sys, os\n",
    "sys.path.append('../')\n",
    "\n",
    "from rnn.vae import VAE\n",
    "from rnn.saving import save_model, load_model\n",
    "from evaluation.klx_gmm import calc_kl_from_data\n",
    "# NOTE(review): star import. `zscore` and `channels` are used further down but are not defined\n",
    "# elsewhere in this notebook; presumably they come from here. Make them explicit once confirmed.\n",
    "from evaluation.pse import *\n",
    "import scipy.ndimage as ndimage\n",
    "import scipy.signal as signal\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "# The prior is sampled with noise (noise_scale=1) below. Seed numpy and torch so the generated\n",
    "# trajectories and the saved figure are reproducible (assumes the sampler draws from these RNGs).\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# 15-sample Hann window, used as a smoothing kernel (ndimage.convolve1d) in the figure cell.\n",
    "window = signal.windows.hann(15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b05c011",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Sample a long free-running trajectory from the trained prior and plot a stretch of\n",
    "# its latent states and the first 10 generated channels.\n",
    "name = \"EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_16_19_03\"\n",
    "# NOTE(review): a later cell loads this model from \"../models/Sweep4/\" + name;\n",
    "# confirm that both calls resolve to the same checkpoint.\n",
    "with torch.no_grad():\n",
    "    vae, params, task_params, training_params = load_model(name)\n",
    "    # Distinct names so the 800-step sample drawn later does not silently replace these.\n",
    "    Z_long = vae.prior.get_latent_time_series(time_steps=6000, cut_off=0, noise_scale=1)\n",
    "    data_gen_long = vae.prior.get_observation(Z_long, noise_scale=1)\n",
    "\n",
    "fig_latent, ax_latent = plt.subplots()\n",
    "ax_latent.plot(Z_long[0, :, 500:1000, 0].detach().T, alpha=.7)\n",
    "ax_latent.spines[['right', 'top']].set_visible(False)\n",
    "ax_latent.set_xlim(0)\n",
    "ax_latent.set_xlabel('time steps (from t=500)')\n",
    "ax_latent.set_ylabel('latent state')\n",
    "\n",
    "fig_obs, ax_obs = plt.subplots()\n",
    "ax_obs.plot(data_gen_long[0, :10, :, 0].detach().T, alpha=0.2, lw=.1)\n",
    "ax_obs.spines[['right', 'top']].set_visible(False)\n",
    "ax_obs.set_xlim(0)\n",
    "ax_obs.set_xlabel('time steps');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb4d2b82",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Load the z-scored EEG recording (raw_dat) and EEG_train.npy (smooth_dat), in that order.\n",
    "raw_dat, smooth_dat = (np.load(path) for path in (\"../data/EEG_data_zscored.npy\", \"../data/EEG_train.npy\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0c2ec42",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Re-load the model and draw an 800-step sample for the EEG-vs-generated figure.\n",
    "t1 = 1200  # row of raw_dat passed through inv_observation below\n",
    "with torch.no_grad():\n",
    "    vae, params, task_params, training_params = load_model(\"../models/Sweep4/\" + name)\n",
    "    # Reshape one row of raw_dat to (1, n, 1), with n inferred from the data rather than fixed at 64.\n",
    "    eeg_frame = torch.from_numpy(raw_dat[t1]).reshape(1, -1, 1).float()\n",
    "    z0 = vae.prior.inv_observation(eeg_frame)\n",
    "    # NOTE(review): z0 is never used, so the trajectory below does not start from the EEG\n",
    "    # state at t1 (and the figure cell shows EEG from t1 = 1500). Confirm whether\n",
    "    # get_latent_time_series can take an initial state.\n",
    "    Z = vae.prior.get_latent_time_series(time_steps=800, cut_off=0, noise_scale=1)\n",
    "    data_gen = vae.prior.get_observation(Z, noise_scale=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "130b4b78",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# raw_dat and smooth_dat are already loaded above, so they are not re-read from disk here.\n",
    "# The figure cell indexes raw_dat as raw_dat[time, channel]; check that layout before plotting.\n",
    "assert raw_dat.ndim == 2, f'expected a 2-D (time, channel) array, got shape {raw_dat.shape}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f7be282",
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Side-by-side figure: recorded EEG (left) vs. model-generated data (right) for a few\n",
    "# channels, stacked vertically, saved to ../figures/fig_EEG.pdf.\n",
    "plt_chs = [2, 9, 32, 49, 61]  # channel indices to display\n",
    "t1 = 1500                     # first EEG sample of the excerpt\n",
    "n_plot = 800                  # samples per panel\n",
    "fs = 160                      # samples per second (x ticks: every 160 samples = 1 s)\n",
    "offset = 4                    # vertical spacing between stacked traces\n",
    "show_smoothed = False         # True: overlay Hann-smoothed, z-scored traces\n",
    "\n",
    "panels = {\n",
    "    'EEG': raw_dat[t1:t1 + n_plot, plt_chs].T,\n",
    "    'generated': np.asarray(data_gen[0, plt_chs, :n_plot, 0]),\n",
    "}\n",
    "if show_smoothed:\n",
    "    # NOTE(review): `zscore` is not defined in this notebook; presumably it comes from\n",
    "    # `from evaluation.pse import *`.\n",
    "    smoothed = {\n",
    "        'EEG': zscore(ndimage.convolve1d(raw_dat.T, window, axis=1), axis=1)[plt_chs, t1:t1 + n_plot],\n",
    "        'generated': zscore(ndimage.convolve1d(np.asarray(data_gen[0, :, :, 0]), window, axis=1), axis=1)[plt_chs, :n_plot],\n",
    "    }\n",
    "\n",
    "fig, axs = plt.subplots(1, 2, figsize=(3, 1))\n",
    "for ax, (title, traces) in zip(axs, panels.items()):\n",
    "    for i, trace in enumerate(traces):\n",
    "        ax.plot(trace + i * offset, lw=1, color='slategrey')\n",
    "        if show_smoothed:\n",
    "            ax.plot(smoothed[title][i] + i * offset, lw=1, color='mediumturquoise')\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlim(0, n_plot)\n",
    "    ax.set_xticks(np.arange(0, n_plot + 1, fs))\n",
    "    ax.set_xticklabels(np.arange(0, n_plot // fs + 1))\n",
    "    ax.set_yticks(np.arange(len(plt_chs)) * offset)\n",
    "    ax.set_yticklabels([])\n",
    "\n",
    "# NOTE(review): `channels` (channel names) is not defined in this notebook; presumably it\n",
    "# comes from `from evaluation.pse import *`.\n",
    "axs[0].set_yticklabels([channels[ch_n].strip('.') for ch_n in plt_chs])\n",
    "axs[0].set_ylabel('channel name')\n",
    "axs[0].set_xlabel('time (s)')\n",
    "\n",
    "fig.savefig('../figures/fig_EEG.pdf', bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
