{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ea98285b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import sys, os\n",
    "sys.path.append('../')\n",
    "  \n",
    "from rnn.encoders import CNN_encoder\n",
    "from rnn.priors import LRRNN\n",
    "from rnn.vae import VAE\n",
    "from rnn.saving import save_model, load_model\n",
    "from rnn.datasets import SU_dataset, transform\n",
    "from rnn.evaluation import eval_VAE\n",
    "from evaluation.klx_gmm import calc_kl_from_data\n",
    "from evaluation.pse import *#power_spectrum_error_per_dim\n",
    "from evaluation.klx import klx_metric\n",
    "import scipy.ndimage as ndimage\n",
    "import scipy.signal as signal\n",
    "window = signal.windows.hann(15)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8cad6ee8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full EEG training recording as a float32 tensor (Trial_gen windows over axis 0 as time)\n",
    "eval_data = torch.from_numpy(np.load('../data/EEG_train.npy').astype(np.float32))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a8093e53",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([9640, 64])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (time steps, channels): Trial_gen slices its windows along axis 0\n",
    "eval_data.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d45340e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Trial_gen(Dataset):\n",
    "    \"\"\"Random fixed-length windows cut from the EEG training recording.\n",
    "\n",
    "    Axis 0 of the loaded array is treated as time. Item `idx` is\n",
    "    data[t:t + dur] transposed (time moves to the last axis), paired with an\n",
    "    empty (0, dur) tensor -- presumably the model's input slot, unused for this\n",
    "    task (TODO confirm against eval_VAE).\n",
    "\n",
    "    task_params keys:\n",
    "        dur      -- window length in time steps\n",
    "        n_trials -- number of windows, i.e. len(dataset)\n",
    "        seed     -- optional int; seeds the window start indices\n",
    "    Other keys (e.g. 'name') are ignored.\n",
    "\n",
    "    data_path: .npy file to load (defaults to the EEG training recording).\n",
    "    Raises ValueError if dur exceeds the recording length.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, task_params, data_path='../data/EEG_train.npy'):\n",
    "        # Read the file once (it used to be loaded a second time for data_eval).\n",
    "        self.data = torch.from_numpy(np.load(data_path).astype(np.float32))\n",
    "        # NOTE(review): data_eval is a copy of the *training* recording, not held-out data - confirm intended.\n",
    "        self.data_eval = self.data.clone()\n",
    "\n",
    "        self.dur = task_params['dur']\n",
    "        self.n_trials = task_params['n_trials']\n",
    "        n_steps = self.data.shape[0]\n",
    "        if self.dur > n_steps:\n",
    "            raise ValueError(f'dur={self.dur} exceeds recording length {n_steps}')\n",
    "\n",
    "        seed = task_params.get('seed')\n",
    "        generator = None if seed is None else torch.Generator().manual_seed(seed)\n",
    "        # torch.randint's `high` is exclusive; +1 so the last full window\n",
    "        # (start = n_steps - dur) can also be drawn.\n",
    "        self.t_indices = torch.randint(low=0, high=n_steps - self.dur + 1,\n",
    "                                       size=(self.n_trials,), generator=generator)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.n_trials\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        t_start = self.t_indices[idx]\n",
    "        t_end = t_start + self.dur\n",
    "        return self.data[t_start:t_end].T, torch.zeros(0, self.dur)\n",
    "\n",
    "\n",
    "# Evaluation dataset passed to eval_VAE in the sweep loop below.\n",
    "task_params = {'dur': 500,\n",
    "               'n_trials': 1000,\n",
    "               'name': 'EEG',\n",
    "               'seed': 0,  # fixed window starts -> re-runs evaluate the same windows\n",
    "              }\n",
    "task = Trial_gen(task_params)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "25747b9b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_09_39_20\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0010000000000000009\n",
      "KL_x = 2.144,  PS_corr = 0.988, PS_dist = 0.096, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_04_32_26\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0020000000000000018\n",
      "KL_x = 2.182,  PS_corr = 0.991, PS_dist = 0.134, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_08_T_23_44_53\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0\n",
      "KL_x = 1.802,  PS_corr = 0.987, PS_dist = 0.097, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_09_17_36\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0\n",
      "KL_x = 2.340,  PS_corr = 0.989, PS_dist = 0.095, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_16_11_44\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0010000000000000009\n",
      "KL_x = 1.828,  PS_corr = 0.987, PS_dist = 0.122, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_08_T_23_42_11\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0\n",
      "KL_x = 1.764,  PS_corr = 0.987, PS_dist = 0.112, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_04_43_23\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0\n",
      "KL_x = 2.737,  PS_corr = 0.991, PS_dist = 0.148, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_16_08_53\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0020000000000000018\n",
      "KL_x = 2.195,  PS_corr = 0.987, PS_dist = 0.081, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_08_T_23_44_08\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0010000000000000009\n",
      "KL_x = 2.256,  PS_corr = 0.984, PS_dist = 0.106, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_04_29_35\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0020000000000000018\n",
      "KL_x = 2.411,  PS_corr = 0.984, PS_dist = 0.100, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_08_T_23_42_48\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0020000000000000018\n",
      "KL_x = 1.965,  PS_corr = 0.988, PS_dist = 0.110, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_16_19_04\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.006000000000000005\n",
      "KL_x = 2.171,  PS_corr = 0.991, PS_dist = 0.104, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_04_35_55\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0020000000000000018\n",
      "KL_x = 2.323,  PS_corr = 0.987, PS_dist = 0.122, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_16_09_32\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0020000000000000018\n",
      "KL_x = 2.091,  PS_corr = 0.989, PS_dist = 0.133, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_09_15_08\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0020000000000000018\n",
      "KL_x = 2.276,  PS_corr = 0.991, PS_dist = 0.132, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_04_33_30\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0050000000000000044\n",
      "KL_x = 2.444,  PS_corr = 0.981, PS_dist = 0.102, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_09_23_32\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0050000000000000044\n",
      "KL_x = 2.425,  PS_corr = 0.987, PS_dist = 0.098, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_09_18_22\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0020000000000000018\n",
      "KL_x = 1.852,  PS_corr = 0.990, PS_dist = 0.108, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_16_32_54\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0010000000000000009\n",
      "KL_x = 2.448,  PS_corr = 0.985, PS_dist = 0.123, Mean_rate_error = 0.000\n",
      "EEG_Inv_Obs_PLRNN_Z_Date_32024_05_09_T_16_19_03\n",
      "using clipped ReLU activation\n",
      "using uniform init\n",
      "weight scaler 1\n",
      "Inv_Obs\n",
      "3\n",
      "n outliers: 0.0020000000000000018\n",
      "KL_x = 1.939,  PS_corr = 0.989, PS_dist = 0.089, Mean_rate_error = 0.000\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every saved model in the sweep directory; group metrics by the model's latent dimension.\n",
    "model_dir = '../models/Sweep'\n",
    "dim_zs = [3, 4]\n",
    "n_dims = len(dim_zs)\n",
    "data_kl = [[] for _ in range(n_dims)]  # KL_x (klx_bin from eval_VAE) per model\n",
    "data_ph = [[] for _ in range(n_dims)]  # power-spectrum metric (psH from eval_VAE) per model\n",
    "noise_z = [[] for _ in range(n_dims)]  # only filled when orth_z is True\n",
    "noise_x = [[] for _ in range(n_dims)]\n",
    "\n",
    "orth_z = False\n",
    "\n",
    "# os.listdir order is arbitrary; sort so the per-dimension lists have a reproducible order.\n",
    "for filename in sorted(os.listdir(model_dir)):\n",
    "    if not filename.endswith('_vae_params.pkl'):\n",
    "        continue\n",
    "    model_name = filename.removesuffix('_vae_params.pkl')\n",
    "    print(model_name)\n",
    "    # `model_task_params`, not `task_params`, so the dict used to build `task` is not clobbered.\n",
    "    vae, params, model_task_params, training_params = load_model(os.path.join(model_dir, model_name))\n",
    "    print(vae.dim_z)\n",
    "    # Look up the group before the expensive eval; raises ValueError if dim_z is not in dim_zs.\n",
    "    dim_idx = dim_zs.index(vae.dim_z)\n",
    "    klx_bin, _, psH, _ = eval_VAE(vae, task, smoothing=20, cut_off=2400, freq_cut_off=-1,\n",
    "                                  sim_obs_noise=1, sim_latent_noise=True, smooth_at_eval=True)\n",
    "    data_kl[dim_idx].append(klx_bin)\n",
    "    data_ph[dim_idx].append(psH)\n",
    "    if orth_z:\n",
    "        # Project the latent noise Cholesky factor through U_r.T @ m, where U_r are the\n",
    "        # leading m.shape[1] left singular vectors of J = m @ n.\n",
    "        m = vae.prior.transition.m_transform(vae.prior.transition.m).detach().numpy()\n",
    "        n = vae.prior.transition.n.detach().numpy()\n",
    "        J = m @ n\n",
    "        U, s, V = np.linalg.svd(J)\n",
    "        proj_matrix = U[:, :m.shape[1]].T @ m\n",
    "        cov_chol = vae.prior.chol_cov_embed(vae.prior.R_z).detach().numpy()\n",
    "        cov_chol = proj_matrix @ cov_chol\n",
    "        cov = cov_chol @ cov_chol.T\n",
    "\n",
    "        B = vae.prior.observation.B.detach().numpy()\n",
    "        # inv(m.T @ m) @ m.T is the left pseudo-inverse of m (needs full column rank).\n",
    "        Readout_proj = B.T @ (np.linalg.inv(m.T @ m) @ m.T)\n",
    "        norm = np.linalg.norm(Readout_proj)\n",
    "        noise_z[dim_idx].append(np.mean(np.sqrt(cov[range(vae.dim_z), range(vae.dim_z)])).item() * norm)\n",
    "    noise_x[dim_idx].append(torch.mean(vae.prior.std_embed_x(vae.prior.R_x)).item())\n"
   ]
  },
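  {
   "cell_type": "markdown",
   "id": "d13e0e93",
   "metadata": {},
   "source": [
    "## Summary statistics\n",
    "\n",
    "Median and median absolute deviation (MAD) of the evaluation metrics collected in the sweep loop above."
   ]
  },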
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "5272e19c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import median_abs_deviation as mad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "f110967a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.10704101358077645 0.011428979866539318\n"
     ]
    }
   ],
   "source": [
    "# Median and MAD (scipy default scale=1.0, i.e. raw MAD) of psH for the first group, dim_zs[0]\n",
    "print(*(stat(data_ph[0]) for stat in (np.median, mad)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "9d9209b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.1886454820632935 0.22282493114471436\n"
     ]
    }
   ],
   "source": [
    "# Median and MAD (scipy default scale=1.0, i.e. raw MAD) of KL_x for the first group, dim_zs[0]\n",
    "print(*(stat(data_kl[0]) for stat in (np.median, mad)))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
