{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea98285b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "import sys, os\n",
    "sys.path.append('../')\n",
    "from rnn.vae import VAE\n",
    "from rnn.train import train_VAE\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from pyrnn.train import load_rnn\n",
    "import matplotlib.pyplot as plt\n",
    "from rnn.saving import save_model, load_model\n",
    "import matplotlib as mpl\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70755ec3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up directories\n",
    "out_dir = \"\"\n",
    "data_path = \"\"\n",
    "cuda=True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "594e63cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the teacher model\n",
    "\n",
    "rnn_osc,model_params,task_params,training_params = load_rnn(\"../data/osc_rnn40\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d2ea56d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Obtain and orthogonalise weights of the teacher model\n",
    "\n",
    "tr = 2 # rank 2\n",
    "U = torch.clone(rnn_osc.rnn.m.detach())\n",
    "V= torch.clone(rnn_osc.rnn.n.detach()*.1/model_params['n_rec'])\n",
    "W_or = U@V\n",
    "U, s, V = torch.linalg.svd(W_or, full_matrices=False)\n",
    "U, s,V = U[:, :tr], s[:tr], V[:tr, :]\n",
    "V = (V.T * s).T\n",
    "\n",
    "B = rnn_osc.rnn.b_rec.detach()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5bf2955",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist((np.random.randint(low=0,high=2,size=(20))-.5)*3+np.random.randn(20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2bd1d07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a dataset using the teacher model\n",
    "\n",
    "class Trial_gen(Dataset):\n",
    "    def __init__(self, task_params):\n",
    "        self.task_params = task_params\n",
    "        self.dur =task_params['dur']\n",
    "        self.n_trials = task_params['n_trials']\n",
    "        self.N = task_params['n_neurons']\n",
    "        self.w = task_params['w']\n",
    "        self.non_lin=task_params['non_lin']\n",
    "        self.R_z = task_params['R_z']\n",
    "\n",
    "        ph0 = torch.randn(self.n_trials)*np.pi*2\n",
    "        r0  = torch.randn(self.n_trials)*2\n",
    "        self.latents = torch.zeros(2, self.n_trials, self.dur, dtype=torch.float32)\n",
    "        self.rates= torch.zeros(self.N, self.n_trials, self.dur, dtype=torch.float32)\n",
    "        self.latents[0,:,0] = r0*np.cos(ph0)\n",
    "        self.latents[1,:,0] = r0*np.sin(ph0)\n",
    "        self.latents[:,:,0]+=torch.randn(2,self.n_trials)*self.R_z\n",
    "\n",
    "        for t in range(1, self.dur):\n",
    "            self.latents[:,:,t]+=0.9*self.latents[:,:,t-1]\n",
    "            self.latents[:,:,t]+=V@self.non_lin(U@self.latents[:,:,t-1]+B.unsqueeze(1))+torch.randn(2,self.n_trials)*self.R_z\n",
    "\n",
    "        if task_params[\"out\"]==\"rates\":\n",
    "            for t in range(self.dur):\n",
    "                self.rates[:,:,t] = U@self.non_lin(self.latents[:,:,t]+B.unsqueeze(1))\n",
    "        elif task_params[\"out\"]==\"currents\":\n",
    "            for t in range(self.dur):\n",
    "                self.rates[:,:,t] = U@self.latents[:,:,t]\n",
    "        #proj_matrix = (torch.randint(low=0,high=2,size=(self.N,1,1))-.5)*task_params['B']#+torch.randn(self.N,1,1)*\n",
    "        self.rates*=task_params['B']\n",
    "        self.rates+=task_params['Bias']\n",
    "\n",
    "        if task_params['obs_rectify'] =='exp':\n",
    "            self.rates =torch.exp(self.rates)\n",
    "        elif task_params['obs_rectify'] =='relu':\n",
    "            self.rates= torch.relu(self.rates)+1e-10\n",
    "        elif task_params['obs_rectify'] =='softplus':\n",
    "            self.rates = torch.nn.functional.softplus(self.rates)\n",
    "        self.data = torch.poisson(self.rates)\n",
    "        self.data_eval= self.data\n",
    "    def __len__(self):\n",
    "        return self.n_trials\n",
    "      \n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"\n",
    "        Return a trial of length self.dur\n",
    "        Args:\n",
    "            idx (int): trial index, arbitrary as trials are sampled randomly\n",
    "        Returns:\n",
    "            trial (torch.tensor; dim_x x self.dur): trial of length self.dur\n",
    "        \"\"\"\n",
    "        \n",
    "        return self.data[:,idx], torch.zeros(0,self.dur,device=self.data.device)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0808ed9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot example trial plus the latent signal underlying it\n",
    "task_params ={\"dur\":200,\n",
    "              \"n_trials\":400,\n",
    "              \"name\":\"Sine\",\n",
    "              \"n_neurons\": 40,\n",
    "              \"out\":\"currents\",\n",
    "              \"w\":.1,\n",
    "              \"R_z\":0.2,\n",
    "              \"Bias\":-3,\n",
    "              \"B\":4,\n",
    "              \"non_lin\":torch.nn.ReLU(),\n",
    "               \"obs_rectify\":\"softplus\",\n",
    "             }\n",
    "task = Trial_gen(task_params)\n",
    "batch_size= 10\n",
    "data_loader = DataLoader(\n",
    "   task, batch_size=batch_size, shuffle=True\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fa322bd",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d387fbbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_i = 0\n",
    "rates = task.data[:,tr_i]\n",
    "logrates = task.rates[:,tr_i]\n",
    "latent_code = task.latents[:,tr_i]\n",
    "fig,ax = plt.subplots(3,figsize=(7,5))\n",
    "indices = np.argsort(np.argmax(logrates[:,-50:],1))\n",
    "#indices = torch.arange(0,20)\n",
    "im = ax[2].imshow(rates[indices],aspect='auto',interpolation='none',cmap='Greys')\n",
    "fig.colorbar(im)\n",
    "for i,ind in enumerate(indices.numpy()):\n",
    "    ax[1].plot(logrates[ind], alpha=np.linspace(0.1,1,40)[i],color='black')\n",
    "    #ax[1].plot(rates[ind], alpha=np.linspace(0.1,1,20)[i],color='black')\n",
    "\n",
    "ax[0].plot(latent_code[0])\n",
    "ax[0].plot(latent_code[1])\n",
    "ax[2].set_xlabel(\"time steps\")\n",
    "ax[0].set_xticks([])\n",
    "ax[1].set_xticks([])\n",
    "ax[2].set_ylabel(\"neurons\")\n",
    "ax[1].set_ylabel(\"neuron activity\")\n",
    "ax[0].set_ylabel('latents')\n",
    "\n",
    "ax[0].spines[['right','top']].set_visible(False)\n",
    "ax[1].spines[['right','top']].set_visible(False)\n",
    "ax[2].spines[['right','top']].set_visible(False)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c48e4f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('mean spike rate assuming 5 ms bins: ' +str(task.data.float().mean()*200))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e40ab22",
   "metadata": {},
   "source": [
    "## Create a VAE RNN setup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "063f90d9",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a268f88f",
   "metadata": {},
   "outputs": [],
   "source": [
    "task_params[\"out\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d465fb1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "dim_z = 2\n",
    "dim_N=40\n",
    "dim_x = task_params['n_neurons']\n",
    "bs = 10\n",
    "cuda = False\n",
    "n_epochs = 1500\n",
    "wandb=False\n",
    "# initialise encoder\n",
    "enc_params = {\n",
    "    \"init_kernel_sizes\":[21,11,1],\n",
    "    \"nonlinearity\":'gelu',\n",
    "    \"n_channels\":[64,64],\n",
    "    \"init_scale\":.1,\n",
    "    \"constant_var\":False,\n",
    "    \"n_hidden\":64,\n",
    "    \"strides\":[1,1,1],\n",
    "    'dilations':[1,1,1],\n",
    "    \"padding\":'same',\n",
    "    \"padding_mode\":\"circular\",\n",
    "    }\n",
    "\n",
    "# initialise prior\n",
    "prior_params={\n",
    "    \"clipped\":False,\n",
    "    \"train_noise_obs\":False,\n",
    "    \"train_noise_prior\":True,\n",
    "    \"train_noise_prior_t0\":True,\n",
    "    \"init_noise_z\":.1,\n",
    "    \"init_noise_z_t0\":.1,\n",
    "    \"init_noise_x\":.1,\n",
    "    \"scalar_noise_z\": False,\n",
    "    \"scalar_noise_x\":False,\n",
    "    \"scalar_noise_z_t0\":False,\n",
    "    'identity_readout':True,\n",
    "    'activation':\"relu\",\n",
    "    \"exp_par\":True,\n",
    "    \"shared_tau\":.9,\n",
    "    \"readout_rates\": task_params[\"out\"],\n",
    "    \"train_obs_bias\":True,\n",
    "    \"train_obs_weights\":True,\n",
    "    \"train_latent_bias\":False,\n",
    "    \"train_neuron_bias\":True,\n",
    "    \"orth\":False,\n",
    "    \"m_norm\":False,\n",
    "    \"weight_dist\":\"uniform\",\n",
    "    \"weight_scaler\":1,#/dim_N,\n",
    "    'initial_state':'trainable'\n",
    "    }\n",
    "\n",
    "\n",
    "training_params = {\n",
    "    \"lr\":1e-3,\n",
    "    \"lr_end\":1e-5,\n",
    "    \"opt_eps\":1e-8,\n",
    "    \"CosineRestarts\":0,\n",
    "    \"beta\":0.5,\n",
    "    \"n_epochs\":n_epochs,\n",
    "    \"regularisation\":\"None\",\n",
    "    \"regularisation_params\": [1e-4,0],#ratio, lambda\n",
    "    \"annealing\":False,\n",
    "    \"annealing_epochs\":500,\n",
    "    \"grad_norm\":0,\n",
    "    \"eval_epochs\":50,\n",
    "    \"batch_size\":bs,\n",
    "    \"cuda\":cuda,\n",
    "    'smoothing':20,\n",
    "    'freq_cut_off':10000,\n",
    "    \"sim_obs_noise\":0,\n",
    "    \"sim_latent_noise\":1,\n",
    "    \"k\":64,\n",
    "    \"loss_f\":\"VGTF\",\n",
    "    \"MC_q\":True,\n",
    "    \"dreg_q\":\"none\",\n",
    "    \"MC_p\":True,\n",
    "    \"dreg_p\":\"none\",\n",
    "    \"resample\":\"systematic\",#, multinomial or none\"\n",
    "    \"L2_reg\":0,\n",
    "    \"observation_likelihood\": \"Poisson\", # observation likelihood\n",
    "    \"bootstrap\":False,\n",
    "    \"alpha\":1,\n",
    "    \"alpha_decay\":.999,\n",
    "    \"alpha_method\":\"mean\",\n",
    "    \"alpha_update_interval\":5,\n",
    "    \"run_eval\":False,\n",
    "    \"t_forward\":0\n",
    "\n",
    "}\n",
    "\n",
    "\n",
    "VAE_params = {\n",
    "    \"dim_x\":dim_x, \n",
    "    \"dim_z\":dim_z,\n",
    "    \"dim_N\":dim_N,\n",
    "    \"enc_architecture\":\"CNN_causal\",\n",
    "    \"enc_params\":enc_params,\n",
    "    \"prior_architecture\":\"PLRNN\",\n",
    "    \"prior_params\":prior_params,\n",
    "    \"causal\":True,\n",
    "    \"obs_rectify\":task_params['obs_rectify'],\n",
    "    }\n",
    "vae=VAE(VAE_params)\n",
    "\n",
    "#with torch.no_grad():\n",
    "    #vae.prior.latent_step.m.copy_(U)\n",
    "    #vae.prior.latent_step.m.requires_grad=False\n",
    "    #vae.prior.observation.Bias.copy_(torch.ones_like(vae.prior.observation.Bias)*task_params['Bias'])\n",
    "    #vae.prior.observation.B.copy_(vae.prior.observation.B*task_params['B'])\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86916eb6",
   "metadata": {},
   "source": [
    "## Train the VAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0bfc852",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb = False\n",
    "training_params[\"lr\"]=1e-3\n",
    "train_VAE(\n",
    "    vae,\n",
    "    training_params,\n",
    "    task,\n",
    "    sync_wandb=wandb,\n",
    "    out_dir=out_dir,\n",
    "    fname=None\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7f522cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_model(vae, training_params, task_params,name =\"../models/Sine_40_1000_v5\")#, directory = '../models/')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "691b0eb4",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9e4b48a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Sine_CNN_causal_PLRNN_Z_Date_22024_05_11_T_11_01_24 trained 1000 epochs\n",
    "#Saving model as Sine_CNN_causal_PLRNN_Z_Date_22024_05_11_T_13_11_25 more particles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3de62cb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Sine_CNN_causal_PLRNN_Z_Date_22024_05_08_T_09_18_19"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f05fde3",
   "metadata": {},
   "outputs": [],
   "source": [
    "m_or = vae.prior.transition.m #20,2\n",
    "n_or = vae.prior.transition.n\n",
    "J = m_or@n_or\n",
    "u,s,v = torch.linalg.svd(J)\n",
    "projection_matrix = u[:,:2].T@m_or\n",
    "proj_chol = projection_matrix@torch.diag(vae.prior.std_embed_z(vae.prior.R_z))\n",
    "print((vae.prior.std_embed_z(vae.prior.R_z)))\n",
    "print(((proj_chol@proj_chol.T)))\n",
    "print(torch.sqrt((proj_chol@proj_chol.T)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0d61017",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(proj_chol@proj_chol.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23b7d72c",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(torch.mean(vae.prior.observation.Bias))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40d7658f",
   "metadata": {},
   "outputs": [],
   "source": [
    "vae.prior.transition.cast_A(vae.prior.transition.AW)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24aba0ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "vae.prior.std_embed_z(vae.prior.R_z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0a5bc36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get some data for plotting\n",
    "\n",
    "cl = vae.encoder.cut_len\n",
    "i=np.random.randint(0,bs)\n",
    "with torch.no_grad():\n",
    "    data= task.data[:,i]\n",
    "    z_hat, Emean,Esigma,eps_s = vae.encoder(data.unsqueeze(0))\n",
    "    z0 = z_hat[:,:,:1].squeeze()\n",
    "    dur = z_hat.shape[2]\n",
    "\n",
    "    Z = vae.prior.get_latent_time_series(time_steps=dur, z0=z0)\n",
    "    data_gen = vae.prior.get_observation(Z,noise_scale=0).permute(0,2,1,3).reshape(dur, dim_x)\n",
    "    if task_params['obs_rectify'] =='exp':\n",
    "        data_gen =torch.exp(data_gen)\n",
    "    elif task_params['obs_rectify'] =='relu':\n",
    "        data_gen= torch.relu(data_gen)+1e-10\n",
    "    elif task_params['obs_rectify'] =='softplus':\n",
    "        data_gen = torch.nn.functional.softplus(data_gen)\n",
    "    latents= task.latents[:,i]\n",
    "    logrates= task.rates[:,i]\n",
    "    dim_x,T = data.shape\n",
    "    z0 = task.latents[:,i,cl].unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e386947",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(logrates.cpu()[:,cl:])\n",
    "plt.colorbar()\n",
    "plt.title(\"True rates\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ea98ace",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "plt.imshow(data_gen.cpu().T)\n",
    "plt.colorbar()\n",
    "plt.title(\"predictd log rates\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d5c3681",
   "metadata": {},
   "outputs": [],
   "source": [
    "vae.prior.observation??"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dcef971",
   "metadata": {},
   "outputs": [],
   "source": [
    "vae.prior.observation.B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13c3333d",
   "metadata": {},
   "outputs": [],
   "source": [
    "vae.prior.observation.Bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87e94e7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "plt.plot(latents.detach().cpu().numpy()[0,cl:],label = 'true', color= 'orange')\n",
    "plt.plot(latents.detach().cpu().numpy()[1,cl:],label = 'true', color= 'blue')\n",
    "\n",
    "zs=[]\n",
    "\n",
    "for i in range(10):\n",
    "    Z = vae.prior.get_latent_time_series(time_steps=T-cl, z0=z0)\n",
    "    plt.plot(Z.detach().cpu().numpy()[0,0,:,0], color ='black',alpha=.2)\n",
    "    plt.plot(Z.detach().cpu().numpy()[0,1,:,0], color ='black',alpha=.2)\n",
    "    zs.append(Z.detach().cpu().numpy()[0,:,:,0])\n",
    "\n",
    "plt.plot(np.mean(zs,axis=0)[0], color ='red',label = 'mean of infer')\n",
    "plt.plot(np.mean(zs,axis=0)[1], color ='teal',label = 'mean of infer')\n",
    "plt.legend()\n",
    "\n",
    "data_gen = vae.prior.get_observation(Z,noise_scale=0)\n",
    "if task_params['obs_rectify'] =='exp':\n",
    "    data_gen =torch.exp(data_gen)\n",
    "elif task_params['obs_rectify'] =='relu':\n",
    "    data_gen= torch.relu(data_gen)+1e-10\n",
    "elif task_params['obs_rectify'] =='softplus':\n",
    "    data_gen = torch.nn.functional.softplus(data_gen)\n",
    "data_gen = data_gen.permute(0,2,1,3).reshape(T-cl, dim_x)\n",
    "plt.figure()\n",
    "plt.plot(logrates.detach().cpu().numpy()[0,cl:], color= 'orange',label = 'true')\n",
    "plt.plot(logrates.detach().cpu().numpy()[1,cl:],color='blue')\n",
    "plt.plot(data_gen.detach().cpu().numpy()[:,0], color='red',label = 'gen')\n",
    "plt.plot(data_gen.detach().cpu().numpy()[:,1], color='green')\n",
    "plt.legend()\n",
    "#plt.plot(torch.poisson(torch.exp(data_gen.detach())).cpu().numpy()[:,0], color='red')\n",
    "#plt.plot(torch.poisson(torch.exp(data_gen.detach())).cpu().numpy()[:,1], color='green')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9e20a2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 10000\n",
    "R = 3\n",
    "D = 5\n",
    "print((D**R) * comb(N,R))\n",
    "r1 = (D**R) * comb(N,R)\n",
    "r = 0\n",
    "for i in range(0,N):\n",
    "    r+=comb(i,R-1)\n",
    "    #D*comb(D*(N-i),R-1)\n",
    "r*= (D**R)\n",
    "print(r)\n",
    "print(r/r1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d77b2a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38475b9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "dur = 150\n",
    "with torch.no_grad():\n",
    "    data= task.__getitem__(0)\n",
    "    dim_x,_ = data.shape\n",
    "    z_hat, Emean,Esigma,eps_s = vae.encoder(data.unsqueeze(0))\n",
    "    print(torch.mean(torch.exp(Esigma/2)))\n",
    "    z0 = z_hat[:,:,:1].squeeze()\n",
    "    Z = vae.prior.get_latent_time_series(time_steps=dur, z0=z0)\n",
    "    data_gen = vae.prior.get_observation(Z,noise_scale=0).permute(0,2,1,3).reshape(dur, dim_x)\n",
    "plt.figure()\n",
    "plt.plot(Z[0,:,:,0].detach().cpu().T);\n",
    "plt.xlim(0)\n",
    "plt.figure()\n",
    "plt.plot(data_gen.detach().cpu());\n",
    "plt.figure()\n",
    "color =['orange','red']\n",
    "i=0\n",
    "plt.plot(Emean[0,i,:,0].detach().cpu().T,color=color[i]);\n",
    "plt.plot(Z[0,i,:z_hat.shape[2],0].detach().cpu().T,ls='--',color=color[i]);\n",
    "i=1\n",
    "plt.plot(Emean[0,i,:,0].detach().cpu().T,color=color[i], label = 'Emean');\n",
    "plt.plot(Z[0,i,:z_hat.shape[2],0].detach().cpu().T,ls='--',color=color[i],label='Z');\n",
    "plt.plot(np.exp(Esigma[0,0,:,0].detach().cpu()).T,color='blue',label='Esigma');\n",
    "plt.plot(data.T,color='black',alpha=.1)\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b184fa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#save_model(vae, training_params, task_params,name =\"../models/Poisson_Rz12\")#, directory = '../models/')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f25ec83c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a51dc483",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a9aacda",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
