{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea98285b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset\n",
    "import sys, os\n",
    "#print(\"Filepath\")\n",
    "#print(str(os.path.dirname(os.path.abspath(__file__))))\n",
    "sys.path.append('../')\n",
    "from rnn.vae import VAE\n",
    "from rnn.train import train_VAE\n",
    "from rnn.datasets import SU_dataset, transform,Basic_dataset\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from pyrnn.model import RNN, predict\n",
    "from pyrnn.train import train_rnn\n",
    "from pyrnn.train import save_rnn, load_rnn\n",
    "import matplotlib.pyplot as plt\n",
    "from rnn.saving import save_model, load_model\n",
    "import matplotlib as mpl\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a4031f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set to True to train the oscillator RNN from scratch; False loads a saved model in the train/load cell.\n",
    "train=False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "218e54cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory passed to train_VAE (empty string: presumably the working directory - TODO confirm).\n",
    "out_dir = \"\"\n",
    "# NOTE(review): data_path is not used anywhere in this notebook.\n",
    "data_path = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c991c226",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters for the oscillator RNN (pyrnn), its training, and the task definition.\n",
    "model_params = {\n",
    "    \"nonlinearity\": \"relu\",           # activation function\n",
    "    \"rank\": 2,                        # rank, set to 0 for full rank\n",
    "    \"n_inp\":1,                 # amount of input units\n",
    "    \"p_inp\": 1,                       # probability of connection input\n",
    "    \"n_rec\": 40,                     # amount of recurrent units\n",
    "    \"p_rec\": 1,                       # probability of connection recurrent\n",
    "    \"n_out\": 1,                       # amount of output units\n",
    "    \"scale_w_inp\": 1,                 # scale input weights\n",
    "    \"scale_w_out\": 1,                 # scale output weights\n",
    "    \"w_rec_dist\": \"gauss\",            # recurrent weight dist, gauss or gamma\n",
    "    \"spectr_rad\": 1,                # gain param, recurrent weights\n",
    "    \"spectr_norm\": True,              # use spectral normalisation on rec weights\n",
    "    \"train_w_inp\": True,              # train input weights\n",
    "    \"train_w_inp_scale\": False,       # train input scaling factor\n",
    "    \"train_w_rec\": True,              # train recurrent weights\n",
    "    \"train_b_rec\": True,             # train recurrent bias\n",
    "    \"train_taus\": False,              # train time constants\n",
    "    \"train_m\": False,                  # train m\n",
    "    \"train_n\": True,\n",
    "    \"scale_n\":1,\n",
    "    \"scale_m\":1,\n",
    "    \"loadings\":None,\n",
    "    \"cov\":None,\n",
    "    \"train_w_out\": True,              # train output weights\n",
    "    \"train_w_out_scale\": False,       # train output scaling factor\n",
    "    \"train_x0\": True,                 # train initial state\n",
    "    \"tau_lims\": [100],                # tau limits (min, max) or (value) in ms\n",
    "    'project_taus':'sigmoid',         # choice of projection map to keep within limits (\"exp\", \"sigmoid\" or \"clip\")\n",
    "    'tau_mean': 100,                  # if tau distribution, specify mean\n",
    "    'tau_std':1,                       # if tau distribution, specify std\n",
    "    \"dt\": 10,                         # timestep in ms\n",
    "    \"noise_std\": 0.05,                # noise std\n",
    "    \"scale_x0\": 0.1,                  # std of initial state, if gaussian\n",
    "    \"conn_mask\":None,                 # connection mask (tensor of size n_rec*n_rec)\n",
    "    \"dale_mask\":None,                 # dale mask (tensor of size n_rec*n_rec with only -1 and 1's on diagonal)\n",
    "}\n",
    "training_params = {\n",
    "    \"n_epochs\": 500,                  # number of passes through possible trials\n",
    "    \"lr\": 10e-3,                       # learning rate (10e-3 == 1e-2)\n",
    "    \"batch_size\": 4,                  # batch size\n",
    "    \"clip_gradient\": 1,               # to avoid explosion of gradients\n",
    "    \"cuda\": False,                     # train on GPU if True\n",
    "    \"loss_fn\": \"mse\",                 # loss function (mse, cos or none)\n",
    "    \"optimizer\": \"adam\",              # optimizer (adam)\n",
    "    \"osc_reg_cost\": 0,                # oscillatory regularisation weight\n",
    "    \"osc_reg_freq\": 2,                # oscillatory regularisation frequency\n",
    "    \"offset_reg_cost\": 0,             # offset regularisation cost\n",
    "    \"l2_rates_reg\":1,                  # l2 regularisation on rates\n",
    "\n",
    "}\n",
    "# NOTE(review): the Reaching task below does not read these task parameters.\n",
    "task_params = {\n",
    "    \"stim_ons\": 20,                   # stimulus onset (units are number of time steps)\n",
    "    \"rand_ons\": 0,                    # randomise onset with this amount\n",
    "    \"stim_dur\": 20,                   # stimulus duration\n",
    "    \"stim_offs\": 0,                  # stimulus offsets\n",
    "    \"delay\": 100,                     # delay length\n",
    "    \"rand_delay\": 0,                  # randomize delay with this amount\n",
    "    \"probe_dur\": 20,                  # probe duration\n",
    "    \"probe_offs\": 0,                 # probe offset\n",
    "    \"response_dur\": 20,               # response duration\n",
    "    \"response_ons\": 0,               # response onset\n",
    "    \"seq_len\": 1,                     # sequence length\n",
    "    \"n_channels\": 1,            # number of stimulus\n",
    "}\n",
    "\n",
    "\n",
    "class Reaching(Dataset):\n",
    "    \"\"\"Autonomous sine-wave generation task (the class name is historical).\n",
    "\n",
    "    Every trial is identical: zero input, a sine target with `n_cycles` full\n",
    "    periods over `n_steps` time steps, and an all-ones loss mask.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, task_params, n_steps=200, n_cycles=4, n_trials=50):\n",
    "        \"\"\"\n",
    "        Initialize the task\n",
    "        Args:\n",
    "            task_params: dictionary containing task parameters (stored only;\n",
    "                the trial shape is set by the keyword arguments below)\n",
    "            n_steps: number of time steps per trial\n",
    "            n_cycles: number of full sine periods per trial\n",
    "            n_trials: nominal dataset length returned by __len__\n",
    "        \"\"\"\n",
    "        self.task_params = task_params\n",
    "        self.n_steps = n_steps\n",
    "        self.n_cycles = n_cycles\n",
    "        self.n_trials = n_trials\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Arbitrary number of trials; they are all identical anyway\"\"\"\n",
    "        return self.n_trials\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"\n",
    "        Returns a trial\n",
    "\n",
    "        Args:\n",
    "            idx, trial index (ignored: all trials are identical)\n",
    "\n",
    "        Returns:\n",
    "            input, Tensor of size [n_steps, 1] (zeros)\n",
    "            target, Tensor of size [n_steps, 1] (sine wave)\n",
    "            mask, Tensor of size [n_steps, 1] (ones)\n",
    "        \"\"\"\n",
    "        inputs = torch.zeros(self.n_steps, 1)\n",
    "        # torch.sin keeps everything in torch (np.sin on a tensor relies on NumPy interop)\n",
    "        phase = torch.linspace(0, 2 * np.pi * self.n_cycles, self.n_steps)\n",
    "        targets = torch.sin(phase).unsqueeze(1)\n",
    "        mask = torch.ones(self.n_steps, 1)\n",
    "        return inputs, targets, mask\n",
    "\n",
    "\n",
    "sine_task = Reaching(task_params)\n",
    "\n",
    "rnn_osc = RNN(model_params)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae04d4cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect one trial: input (zeros), sine target and loss mask.\n",
    "x, y, m = sine_task[0]\n",
    "fig, ax = plt.subplots(figsize=(8, 2))\n",
    "ax.plot(x, label=\"input\")\n",
    "ax.plot(y, label=\"target\")\n",
    "ax.plot(m, label=\"mask\")\n",
    "ax.set_xlabel(\"timesteps\")\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "594e63cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Either train the oscillator RNN or load a previously saved one.\n",
    "# NOTE(review): the (commented) save cell writes \"../data/osc_rnn40\" while this loads \"../data/osc_rnn\",\n",
    "# and a freshly trained model is not saved automatically - confirm which checkpoint is intended.\n",
    "if train:\n",
    "    losses, reg_losses = train_rnn(rnn_osc, training_params, sine_task, sync_wandb=False)\n",
    "else:\n",
    "    # load_rnn also replaces model_params, task_params and training_params with the saved ones\n",
    "    rnn_osc,model_params,task_params,training_params = load_rnn(\"../data/osc_rnn\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb16fb01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: persist a freshly trained oscillator RNN (note: a different path from the one the load branch reads).\n",
    "#save_rnn(\"../data/osc_rnn40\",rnn_osc,model_params,task_params,training_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9eaf2e0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the RNN for 1000 steps with zero input and plot its output.\n",
    "# (The low-rank factors are extracted in the W_or cell below.)\n",
    "rates, pred = predict(rnn_osc, torch.zeros(1000,1))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 2))\n",
    "ax.plot(pred[0,:,:])\n",
    "ax.set_xlabel(\"timesteps\")\n",
    "ax.set_ylabel(\"output\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c239054e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot nonlinearity(rates[0]) per unit; assumes predict returns pre-activation states with a leading batch axis - TODO confirm.\n",
    "plt.plot(rnn_osc.rnn.nonlinearity(torch.from_numpy(rates[0])));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71b077dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Effective recurrent matrix W_or = U @ V with U = m and V = n * 0.1 / n_rec\n",
    "# (scaling assumed to mirror the RNN's internal low-rank parametrisation - TODO confirm).\n",
    "U = rnn_osc.rnn.m.detach().clone()\n",
    "V = (rnn_osc.rnn.n.detach() * .1 / model_params['n_rec']).clone()\n",
    "W_or = U @ V\n",
    "plt.imshow(W_or)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d2ea56d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank-tr truncated SVD of W_or. Because W_or = m @ n has rank 2, tr = 2 reproduces it:\n",
    "# U (n_rec x tr) gets orthonormal columns and V (tr x n_rec) absorbs the singular values.\n",
    "tr = 2\n",
    "U, s, V = torch.linalg.svd(W_or, full_matrices=False)\n",
    "U = U[:, :tr]\n",
    "s = s[:tr]\n",
    "V = (V[:tr, :].T * s).T\n",
    "W = U @ V\n",
    "print(torch.linalg.norm(W_or - W))  # ~0 up to floating-point error\n",
    "plt.imshow(W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4194e1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check that the truncated left singular vectors are orthonormal.\n",
    "torch.linalg.norm(U.T@U-torch.eye(U.shape[1])) #should be ~0 (U^T U is the identity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9d13f9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recurrent bias of the trained RNN; together with U and V it parametrises the data generator Trial_gen.\n",
    "B = rnn_osc.rnn.b_rec.detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad2c48c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project the RNN states onto the 2-D column space of W (columns of U).\n",
    "z = rates[0]@U.numpy()\n",
    "fig, ax = plt.subplots(1,2,figsize=(6,3))\n",
    "ax[0].plot(z)\n",
    "ax[0].set_xlabel(\"timesteps\")\n",
    "ax[0].set_ylabel(\"projection\")\n",
    "ax[1].plot(z[:,0],z[:,1])\n",
    "ax[1].set_xlabel(\"$z_1$\")\n",
    "ax[1].set_ylabel(\"$z_2$\")\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb72b5e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def dF(r, theta):\n",
    "    \"\"\"Toy discrete map in polar coordinates.\n",
    "\n",
    "    Returns the radial increment 0.5*(r - r**3), which vanishes at r = 0 and r = 1,\n",
    "    and a constant angular increment of 1 rad. `theta` is accepted but unused.\n",
    "    \"\"\"\n",
    "    return 0.5*(r - r**3), 1\n",
    "\n",
    "grid = np.linspace(-3.0, 3.0, 30)\n",
    "X, Y = np.meshgrid(grid, grid)\n",
    "u = np.zeros_like(X)\n",
    "v = np.zeros_like(X)\n",
    "\n",
    "# Displacement of every grid point after one application of the map.\n",
    "for idx in np.ndindex(X.shape):\n",
    "    px, py = X[idx], Y[idx]\n",
    "    radius = (px**2 + py**2)**0.5\n",
    "    angle = np.arctan2(py, px)\n",
    "    d_r, d_angle = dF(radius, angle)\n",
    "    u[idx] = (radius + d_r) * np.cos(angle + d_angle) - px\n",
    "    v[idx] = (radius + d_r) * np.sin(angle + d_angle) - py\n",
    "\n",
    "plt.streamplot(X, Y, u, v)\n",
    "plt.axis('square')\n",
    "plt.axis([-3, 3, -3, 3])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2bd1d07",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Trial_gen(Dataset):\n",
    "    \"\"\"Synthetic trials from a noisy 2-D latent system driven by a low-rank RNN.\n",
    "\n",
    "    Latent dynamics (z is 2 x n_trials at every time step):\n",
    "        z_t = 0.9 * z_{t-1} + V @ non_lin(U @ z_{t-1} + B) + R_z * noise\n",
    "    Observations (N x n_trials per step), before adding Gaussian noise of std R_x:\n",
    "        out == 'rates':    non_lin(U @ z_t + B)\n",
    "        out == 'currents': U @ z_t\n",
    "\n",
    "    U (N x 2), V (2 x N) and B (N,) are read from task_params['U'], ['V'], ['B'] when\n",
    "    present, otherwise from the notebook-level globals U, V, B set by the RNN cells above.\n",
    "    An optional task_params['seed'] makes the generated trials reproducible.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, task_params):\n",
    "        self.dur = task_params['dur']\n",
    "        self.n_trials = task_params['n_trials']\n",
    "        self.N = task_params['n_neurons']\n",
    "        self.w = task_params['w']  # stored for reference; not used by the generator\n",
    "        self.non_lin = task_params['non_lin']\n",
    "        self.R_z = task_params['R_z']  # latent noise std\n",
    "        self.R_x = task_params['R_x']  # observation noise std\n",
    "\n",
    "        # Connectivity: explicit task_params entries win, else fall back to notebook globals.\n",
    "        U_ = task_params['U'] if 'U' in task_params else U\n",
    "        V_ = task_params['V'] if 'V' in task_params else V\n",
    "        B_ = task_params['B'] if 'B' in task_params else B\n",
    "        if U_.shape[0] != self.N:\n",
    "            raise ValueError(\n",
    "                f\"n_neurons={self.N} does not match U.shape[0]={U_.shape[0]}; \"\n",
    "                \"use the network size of the RNN that provides U, V and B\")\n",
    "        out = task_params[\"out\"]\n",
    "        if out not in (\"rates\", \"currents\"):\n",
    "            raise ValueError(f\"task_params['out'] must be 'rates' or 'currents', got {out!r}\")\n",
    "\n",
    "        # Seeded generator if requested; None keeps drawing from torch's global RNG.\n",
    "        gen = None\n",
    "        if task_params.get('seed') is not None:\n",
    "            gen = torch.Generator().manual_seed(task_params['seed'])\n",
    "\n",
    "        # Initial latent: random radius and phase, plus latent noise.\n",
    "        # NOTE(review): the phase is Gaussian (std 2*pi), not uniform on [0, 2*pi).\n",
    "        ph0 = torch.randn(self.n_trials, generator=gen)*np.pi*2\n",
    "        r0 = torch.randn(self.n_trials, generator=gen)*2\n",
    "        self.latents = torch.zeros(2, self.n_trials, self.dur, dtype=torch.float32)  # (Dz, n_trials, T)\n",
    "        self.data = torch.zeros(self.N, self.n_trials, self.dur, dtype=torch.float32)  # (N, n_trials, T)\n",
    "        self.latents[0,:,0] = r0*torch.cos(ph0)\n",
    "        self.latents[1,:,0] = r0*torch.sin(ph0)\n",
    "        self.latents[:,:,0] += torch.randn(2, self.n_trials, generator=gen)*self.R_z\n",
    "\n",
    "        bias = B_.unsqueeze(1)  # (N, 1), broadcast over trials\n",
    "        for t in range(1, self.dur):\n",
    "            z_prev = self.latents[:,:,t-1]\n",
    "            drift = V_@self.non_lin(U_@z_prev + bias)\n",
    "            noise = torch.randn(2, self.n_trials, generator=gen)*self.R_z\n",
    "            self.latents[:,:,t] = 0.9*z_prev + (drift + noise)\n",
    "\n",
    "        for t in range(self.dur):\n",
    "            z_t = self.latents[:,:,t]\n",
    "            if out == \"rates\":\n",
    "                self.data[:,:,t] = self.non_lin(U_@z_t + bias)\n",
    "            else:\n",
    "                self.data[:,:,t] = U_@z_t\n",
    "        self.data += torch.randn(self.N, self.n_trials, self.dur, generator=gen)*self.R_x\n",
    "        self.data_eval = self.data\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.n_trials\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"\n",
    "        Return a trial of length self.dur\n",
    "        Args:\n",
    "            idx (int): index of one of the pre-generated trials\n",
    "        Returns:\n",
    "            trial (torch.tensor; N x self.dur): noisy observations of trial idx\n",
    "            inputs (torch.tensor; 0 x self.dur): empty input tensor (the task has no inputs)\n",
    "        \"\"\"\n",
    "        return self.data[:,idx], torch.zeros(0,self.data.shape[2],device=self.data.device)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0808ed9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the synthetic dataset from the low-rank RNN (U, V, B) defined above;\n",
    "# an example trial and its latents are plotted in the next cells.\n",
    "batch_size=4\n",
    "task_params ={\"dur\":75,               # time steps per trial\n",
    "              \"n_trials\":400,\n",
    "              \"name\":\"Sine\",\n",
    "              \"n_neurons\": 20,         # must equal U.shape[0]\n",
    "              \"out\":\"currents\",       # observe U @ z ('currents') or non_lin(U @ z + B) ('rates')\n",
    "              \"w\":.1,                  # stored by Trial_gen but not used\n",
    "              \"R_x\":.1,                # observation noise std\n",
    "              \"R_z\":.2,                # latent noise std\n",
    "              \"non_lin\":nn.ReLU(),\n",
    "             }\n",
    "task = Trial_gen(task_params)\n",
    "# NOTE(review): data_loader is not used below; train_VAE is given `task` directly.\n",
    "data_loader = DataLoader(\n",
    "   task, batch_size=batch_size, shuffle=True\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fa322bd",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d387fbbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_trial(rates, latent_code, t_max=None):\n",
    "    \"\"\"Plot the latents, neuron traces and an activity heatmap of one trial.\n",
    "\n",
    "    Args:\n",
    "        rates: (n_neurons, T) observations of the trial\n",
    "        latent_code: (dim_z, T) latent trajectory of the trial\n",
    "        t_max: optional number of leading time steps to show (None = all)\n",
    "    Returns:\n",
    "        (fig, ax) of the 3-panel figure\n",
    "    \"\"\"\n",
    "    rates = rates[:, :t_max]\n",
    "    latent_code = latent_code[:, :t_max]\n",
    "    fig, ax = plt.subplots(3, figsize=(7, 5))\n",
    "    for z_k in latent_code:  # each latent dimension exactly once\n",
    "        ax[0].plot(z_k)\n",
    "    ax[1].plot(rates.T)\n",
    "    ax[2].imshow(rates, aspect='auto', interpolation='none')\n",
    "    ax[2].set_xlabel(\"time steps\")\n",
    "    ax[0].set_xticks([])\n",
    "    ax[1].set_xticks([])\n",
    "    ax[2].set_ylabel(\"neurons\")\n",
    "    ax[1].set_ylabel(\"neuron activity\")\n",
    "    ax[0].set_ylabel('latents')\n",
    "    for a in ax:\n",
    "        a.spines[['right', 'top']].set_visible(False)\n",
    "    return fig, ax\n",
    "\n",
    "\n",
    "# Example trial: latent phase-plane trajectory, then the full trial and its first 100 steps.\n",
    "tr_i = 0\n",
    "rates = task.data[:,tr_i]\n",
    "latent_code = task.latents[:,tr_i]\n",
    "fig, ax = plt.subplots(1, figsize=(7, 5))\n",
    "ax.plot(latent_code[0].numpy(), latent_code[1].numpy())\n",
    "ax.set_xlabel(\"$z_1$\")\n",
    "ax.set_ylabel(\"$z_2$\")\n",
    "\n",
    "plot_trial(rates, latent_code)\n",
    "plot_trial(rates, latent_code, t_max=100);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7de9c04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phase-plane view of trial tr_i and its first n_obs neurons, offset vertically for readability.\n",
    "rates = task.data[:,tr_i]\n",
    "latent_code = task.latents[:,tr_i]\n",
    "# Time window to show; T2 = trial length so the final time step is included\n",
    "# (a stop index of -1 would drop it).\n",
    "T1 = 0\n",
    "T2 = latent_code.shape[-1]\n",
    "\n",
    "fig,ax = plt.subplots(1,figsize=(7,5))\n",
    "ax.plot(latent_code[0,T1:T2].numpy(),latent_code[1,T1:T2].numpy())\n",
    "ax.set_xlabel(\"$z_1$\")\n",
    "ax.set_ylabel(\"$z_2$\")\n",
    "ax.spines[['right','top']].set_visible(False)\n",
    "\n",
    "fig,ax = plt.subplots(1,figsize=(7,5))\n",
    "n_obs = 5\n",
    "for i in range(n_obs):\n",
    "    ax.plot(rates[i,T1:T2]+i*2)  # 1-D slice: no transpose needed\n",
    "ax.set_xlabel(\"time steps\")\n",
    "ax.set_ylabel(\"neurons (offset)\")\n",
    "ax.spines[['right','top']].set_visible(False)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e40ab22",
   "metadata": {},
   "source": [
    "## Set up a VAE with a low-rank RNN (PLRNN) prior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d465fb1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# VAE with a CNN encoder and a PLRNN prior, fitted to the synthetic trials in `task`.\n",
    "dim_z = 2                          # latent dimensionality (matches the 2-D generator)\n",
    "dim_N=20                           # passed to the VAE as dim_N (presumably units in the RNN prior)\n",
    "dim_x = task_params['n_neurons']   # observed dimensionality\n",
    "bs = 10\n",
    "cuda = torch.cuda.is_available()   # use the GPU only when one is actually present\n",
    "n_epochs = 1000\n",
    "wandb=False                        # sync_wandb flag passed to train_VAE\n",
    "# initialise encoder\n",
    "enc_params = {\n",
    "    \"init_kernel_sizes\":[11,5,5,3],\n",
    "    \"strides\":[1]*4,\n",
    "    \"padding\":'valid',#'same' or 'valid'\n",
    "    \"padding_mode\":'circular', #'zeros' or 'circular' or 'reflect'\n",
    "    \"nonlinearity\":'gelu',\n",
    "    \"dilations\":[1]*4,\n",
    "    \"n_channels\":[24,16,16,16],\n",
    "    \"n_hidden\":24,\n",
    "    \"init_scale\":.1,\n",
    "    \"constant_var\":True,\n",
    "    \"obs_grad\":False,\n",
    "\n",
    "    }\n",
    "\n",
    "# initialise prior\n",
    "prior_params={\n",
    "    \"train_noise_obs\":True,\n",
    "    \"train_noise_prior\":True,\n",
    "    \"train_noise_prior_t0\":True,\n",
    "    \"init_noise_z\":.1,\n",
    "    \"init_noise_z_t0\":1,\n",
    "    \"init_noise_x\":task_params['R_x'],\n",
    "    \"scalar_noise_z\":\"Cov\",\n",
    "    \"scalar_noise_x\":True,\n",
    "    \"scalar_noise_z_t0\":\"Cov\",\n",
    "    'identity_readout':True,\n",
    "    'activation':\"relu\",\n",
    "    \"exp_par\":True,\n",
    "    \"shared_tau\":True,\n",
    "    \"readout_rates\": task_params[\"out\"],\n",
    "    \"train_obs_bias\":False,\n",
    "    \"train_obs_weights\":False,\n",
    "    \"train_latent_bias\":False,\n",
    "    \"train_neuron_bias\":True,\n",
    "    \"orth\":False,\n",
    "    \"m_norm\":False,\n",
    "    \"weight_dist\":\"uniform\",\n",
    "    \"weight_scaler\":1,#/dim_N,\n",
    "    'initial_state':'trainable'\n",
    "    }\n",
    "\n",
    "\n",
    "training_params = {\n",
    "    \"lr\":1e-3,\n",
    "    \"lr_end\":1e-5,\n",
    "    \"opt_eps\":1e-8,\n",
    "    \"CosineRestarts\":0,\n",
    "    \"beta\":0.5,\n",
    "    \"n_epochs\":n_epochs,\n",
    "    \"regularisation\":\"None\",\n",
    "    \"regularisation_params\": [1e-4,0],#ratio, lambda\n",
    "    \"annealing\":False,\n",
    "    \"annealing_epochs\":500,\n",
    "    \"grad_norm\":0,\n",
    "    \"bootstrap\":False,\n",
    "    \"eval_epochs\":50,\n",
    "    \"batch_size\":bs,\n",
    "    \"cuda\":cuda,\n",
    "    'smoothing':20,\n",
    "    'freq_cut_off':10000,\n",
    "    \"sim_obs_noise\":0,\n",
    "    \"sim_latent_noise\":1,\n",
    "    \"k\":64,\n",
    "    \"loss_f\":\"opt_VGTF\",\n",
    "    \"MC_q\":True,\n",
    "    \"dreg_q\":\"none\",\n",
    "    \"MC_p\":True,\n",
    "    \"dreg_p\":\"none\",\n",
    "    \"resample\":\"systematic\",#, multinomial or none\"\n",
    "    \"L2_reg\":0,\n",
    "    \"observation_likelihood\": \"Gauss\", # observation likelihood\n",
    "    \"alpha\":.25,\n",
    "    \"alpha_decay\":.999,\n",
    "    \"alpha_method\":\"constant\",\n",
    "    \"alpha_update_interval\":5,\n",
    "    \"run_eval\":True,\n",
    "    \"smooth_at_eval\":False\n",
    "}\n",
    "\n",
    "\n",
    "\n",
    "VAE_params = {\n",
    "    \"dim_x\":dim_x, \n",
    "    \"dim_z\":dim_z,\n",
    "    \"dim_N\":dim_N,\n",
    "    \"enc_architecture\":\"CNN\",\n",
    "    \"enc_params\":enc_params,\n",
    "    \"prior_architecture\":\"PLRNN\",\n",
    "    \"prior_params\":prior_params,\n",
    "    \"causal\":True,\n",
    "    }\n",
    "vae=VAE(VAE_params)\n",
    "\n",
    "# Optional (disabled): pin the prior's m to the generator's U and freeze it.\n",
    "#with torch.no_grad():\n",
    "    #vae.prior.transition.m.copy_(U)\n",
    "    #vae.prior.transition.m.requires_grad=False\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86916eb6",
   "metadata": {},
   "source": [
    "## Train the VAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ed9b677",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent noise scale (std_embed_z of R_z) of the untrained prior, for comparison after training.\n",
    "print(vae.prior.std_embed_z(vae.prior.R_z))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a426e72b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observation noise (var_embed_x of R_x) of the untrained prior; initialised from prior_params['init_noise_x'].\n",
    "print(vae.prior.var_embed_x(vae.prior.R_x))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0bfc852",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train several VAEs from independent random initialisations and save each one.\n",
    "# After the loop, `vae` holds the last trained model, which the cells below analyse.\n",
    "n_runs = 3\n",
    "for i in range(n_runs):\n",
    "    vae=VAE(VAE_params)\n",
    "\n",
    "    train_VAE(\n",
    "        vae,\n",
    "        training_params,\n",
    "        task,\n",
    "        sync_wandb=wandb,\n",
    "        out_dir=out_dir,\n",
    "        fname=None\n",
    "        )\n",
    "\n",
    "    # Latent noise expressed in the basis of the top-2 left singular vectors of J = m @ n\n",
    "    # (the column space of m). Report the per-direction std, i.e. the sqrt of the covariance\n",
    "    # diagonal (an elementwise sqrt of the full matrix is NaN for negative off-diagonals).\n",
    "    with torch.no_grad():\n",
    "        m_or = vae.prior.transition.m #20,2\n",
    "        n_or = vae.prior.transition.n\n",
    "        J = m_or@n_or\n",
    "        u,s,v = torch.linalg.svd(J)\n",
    "        projection_matrix = u[:,:2].T@m_or\n",
    "        proj_chol = projection_matrix@vae.prior.chol_cov_embed(vae.prior.R_z)\n",
    "        print(torch.sqrt(torch.diagonal(proj_chol@proj_chol.T)))\n",
    "    save_model(vae, training_params, task_params,name =f\"../models/SW20_{i}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a760675",
   "metadata": {},
   "source": [
    "## Inspect the last trained VAE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "114f8fd8",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "638282cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent noise of the last trained VAE in the top-2 singular basis of J = m @ n;\n",
    "# prints the per-direction std (sqrt of the projected covariance diagonal).\n",
    "with torch.no_grad():\n",
    "    m_or = vae.prior.transition.m #20,2\n",
    "    n_or = vae.prior.transition.n\n",
    "    J = m_or@n_or\n",
    "    u,s,v = torch.linalg.svd(J)\n",
    "    projection_matrix = u[:,:2].T@m_or\n",
    "    proj_chol = projection_matrix@vae.prior.chol_cov_embed(vae.prior.R_z)\n",
    "    print(torch.sqrt(torch.diagonal(proj_chol@proj_chol.T)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4476a84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kalman-style gain of one observation update, treating H = m (dim_N x dim_z) as the linear\n",
    "# observation map, P as the prior latent covariance and R as the observation noise covariance:\n",
    "#   K = P H^T (R + H P H^T)^-1,   alpha = K H   (dim_z x dim_z)\n",
    "# R uses the generative observation variance R_x**2 (0.01 for R_x = 0.1). Named H rather than B\n",
    "# so the generator bias B used by Trial_gen is not overwritten.\n",
    "with torch.no_grad():\n",
    "    H = vae.prior.transition.m\n",
    "    eff_var_prior = vae.prior.full_cov_embed(vae.prior.R_z)\n",
    "    eff_var_x_diag = torch.eye(H.shape[0], device=H.device, dtype=H.dtype)*task_params['R_x']**2\n",
    "    Kalman_gain = eff_var_prior@H.T@torch.linalg.inv(eff_var_x_diag+H@eff_var_prior@H.T)\n",
    "    alpha = Kalman_gain@H\n",
    "print(alpha)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88ec8858",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average entry of the dim_z x dim_z gain matrix alpha computed above.\n",
    "print(torch.mean(alpha))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f15c5dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full latent noise covariance of the trained prior (as returned by full_cov_embed).\n",
    "vae.prior.full_cov_embed(vae.prior.R_z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "099bcf40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent noise scale of the trained prior (compare with the value printed before training).\n",
    "print(vae.prior.std_embed_z(vae.prior.R_z))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5a04b0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learned initial state of the prior (prior_params['initial_state'] = 'trainable').\n",
    "vae.prior.get_initial_state()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c28b5cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional manual save of the current model (the training loop already saves SW20_<i>).\n",
    "#save_model(vae, training_params, task_params,name =\"../models/SW20_20\")#, directory = '../models/')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7872a4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: reload a saved model; load_model presumably returns it, so assign the result before use.\n",
    "#load_model('SW2')#'SW05_3000'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0a5bc36",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3ed1a60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First entries of the learned noise scales: observation (R_x), latent (R_z) and initial latent (R_z_t0).\n",
    "print(vae.prior.std_embed_x(vae.prior.R_x)[0])\n",
    "print(vae.prior.std_embed_z(vae.prior.R_z)[0])\n",
    "print(vae.prior.std_embed_z(vae.prior.R_z_t0)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "089dba02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder variance for the initial state (assuming logvar_t0 is a log-variance, as the name suggests - TODO confirm).\n",
    "torch.exp(vae.encoder.logvar_t0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87e94e7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare one random trial's true latents/observations with rollouts of the trained prior\n",
    "# started from the encoder's estimate of the initial state.\n",
    "cl = vae.encoder.cut_len  # leading steps not covered by the rollout (presumably trimmed by the encoder)\n",
    "trial_idx = np.random.randint(0, len(task))  # any generated trial, not just the first batch\n",
    "with torch.no_grad():\n",
    "    data= task.data[:,trial_idx]\n",
    "    latents= task.latents[:,trial_idx]\n",
    "\n",
    "    dim_x,T = data.shape\n",
    "    z0 = vae.encoder.mean_t0(data[:,0].unsqueeze(0))\n",
    "plt.figure()\n",
    "plt.plot(latents.detach().cpu().numpy()[0,cl:],label = 'true $z_1$', color= 'orange')\n",
    "plt.plot(latents.detach().cpu().numpy()[1,cl:],label = 'true $z_2$', color= 'blue')\n",
    "\n",
    "zs=[]\n",
    "\n",
    "for _ in range(10):\n",
    "    Z = vae.prior.get_latent_time_series(time_steps=T-cl, z0=z0)\n",
    "    plt.plot(Z.detach().cpu().numpy()[0,0,:,0], color ='black',alpha=.2)\n",
    "    plt.plot(Z.detach().cpu().numpy()[0,1,:,0], color ='black',alpha=.2)\n",
    "    zs.append(Z.detach().cpu().numpy()[0,:,:,0])\n",
    "\n",
    "plt.plot(np.mean(zs,axis=0)[0], color ='red',label = 'rollout mean $z_1$')\n",
    "plt.plot(np.mean(zs,axis=0)[1], color ='teal',label = 'rollout mean $z_2$')\n",
    "plt.xlabel('timesteps')\n",
    "plt.legend()\n",
    "\n",
    "# Observations generated from the last rollout vs the first two true neurons.\n",
    "data_gen = vae.prior.get_observation(Z,noise_scale=1).permute(0,2,1,3).reshape(T-cl, dim_x)\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(data.detach().cpu().numpy()[0,cl:], color= 'orange')\n",
    "plt.plot(data.detach().cpu().numpy()[1,cl:],color='blue')\n",
    "plt.plot(data_gen.detach().cpu().numpy()[:,0], color='red')\n",
    "plt.plot(data_gen.detach().cpu().numpy()[:,1],color='teal')\n",
    "plt.xlabel('timesteps');\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38475b9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode trial 0, then roll the prior forward for `dur` steps from the encoded initial state.\n",
    "dur = 100\n",
    "with torch.no_grad():\n",
    "    data, _ = task[0]  # Trial_gen returns (observations, empty inputs)\n",
    "    dim_x,_ = data.shape\n",
    "    # Esigma appears to be a log-variance: exp(Esigma/2) is treated as the std.\n",
    "    z_hat, Emean,Esigma,eps_s = vae.encoder(data.unsqueeze(0))\n",
    "    print(torch.mean(torch.exp(Esigma/2)))\n",
    "    z0 = z_hat[:,:,:1].squeeze()\n",
    "    Z = vae.prior.get_latent_time_series(time_steps=dur, z0=z0)\n",
    "    data_gen = vae.prior.get_observation(Z,noise_scale=1).permute(0,2,1,3).reshape(dur, dim_x)\n",
    "plt.figure()\n",
    "plt.plot(Z[0,:,:,0].detach().cpu().T)\n",
    "plt.xlim(0)\n",
    "plt.figure()\n",
    "plt.plot(data_gen.detach().cpu())\n",
    "plt.figure()\n",
    "color =['orange','red']\n",
    "for i in range(2):\n",
    "    # encoder mean (solid) vs prior rollout (dashed) over the encoded window\n",
    "    plt.plot(Emean[0,i,:,0].detach().cpu(),color=color[i])\n",
    "    plt.plot(Z[0,i,:z_hat.shape[2],0].detach().cpu(),ls='--',color=color[i])\n",
    "# encoder variance exp(Esigma) of latent 0\n",
    "plt.plot(torch.exp(Esigma[0,0,:,0].detach().cpu()),color='blue');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad49d68d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder std of latent 0 over time, assuming Esigma is a log-variance (std = exp(Esigma / 2)).\n",
    "plt.plot(np.exp(Esigma[0,0,:,0].detach().cpu()/2).T,color='blue');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "626f4cff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the rollout length currently set in `dur`.\n",
    "dur"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c55d47f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free-running prior rollout of `dur` steps from the fixed initial state z0 = (1, 1).\n",
    "z0=torch.ones(1,2,1,1)\n",
    "dur = 300\n",
    "Z = vae.prior.get_latent_time_series(time_steps=dur, z0=z0)\n",
    "Zn = Z.cpu().detach().numpy()[0,:,:,0]  # (dim_z, dur) latent path\n",
    "data_gen = vae.prior.get_observation(Z,noise_scale=1).permute(0,2,1,3).reshape(dur, dim_x)\n",
    "\n",
    "rates = data_gen.cpu().detach().numpy().T  # (dim_x, dur)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49c4b35e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phase portrait of the trained prior's latent map, overlaid with the rollout Zn, plus observed traces.\n",
    "# The prior's one-step update is re-implemented in NumPy from its learned parameters.\n",
    "def relu(x):\n",
    "    \"\"\"Elementwise ReLU in NumPy.\"\"\"\n",
    "    return np.maximum(x,0)\n",
    "\n",
    "prior = vae.prior.latent_step\n",
    "# NOTE(review): U, V, B are reused here for the learned prior, overwriting the generator\n",
    "# parameters that Trial_gen reads from the notebook globals.\n",
    "tau = prior.cast_A(prior.AW).cpu().detach().numpy().squeeze()  # decay factor multiplying z\n",
    "V = (prior.n*prior.scaling).cpu().detach().numpy()\n",
    "U = prior.m_transform(prior.m).cpu().detach().numpy()\n",
    "B = prior.h.detach().cpu().numpy()\n",
    "\n",
    "def dyn_eq(x,y):\n",
    "    \"\"\"Finite-difference velocity (z_next - z) / h of the prior map at z = (x, y).\n",
    "\n",
    "    Uses z_next = tau * z + V @ relu(U @ z - B); here the bias enters with a minus sign,\n",
    "    unlike Trial_gen (assumed to match the PLRNN convention - TODO confirm).\n",
    "    \"\"\"\n",
    "    z = np.array([x,y])\n",
    "    zn=(tau)*z+V@relu(U@z-B)\n",
    "    h=.1  # NOTE(review): arbitrary scale; only changes arrow magnitudes, not directions\n",
    "    dz = (zn-z)/h\n",
    "    return dz[0],dz[1]\n",
    "\n",
    "xlim1=2\n",
    "xlim2=2\n",
    "ylims =2\n",
    "\n",
    "X, Y = np.meshgrid(np.linspace(-xlim1, xlim2, 30), np.linspace(-ylims, ylims, 30))\n",
    "u, v = np.zeros_like(X), np.zeros_like(X)\n",
    "NI, NJ = X.shape\n",
    "\n",
    "# log speed of the map at each grid point (background image)\n",
    "norm = np.zeros((NI,NJ))\n",
    "for i in range(NI):\n",
    "    for j in range(NJ):\n",
    "        x, y = X[i, j], Y[i, j]\n",
    "        dx, dy = dyn_eq(x,y)\n",
    "        u[i,j] = dx\n",
    "        v[i,j] = dy\n",
    "        norm[i,j]=np.log(np.linalg.norm([dx,dy]))  # -inf exactly at a fixed point\n",
    "\n",
    "# NOTE(review): requires a 'matplotlibrc' file in the working directory.\n",
    "with mpl.rc_context(fname='matplotlibrc'):\n",
    "\n",
    "    fig,ax = plt.subplots(1,2,figsize=(5,2))\n",
    "\n",
    "    ax[0].imshow(norm,extent = [-xlim1,xlim2,-ylims,ylims], \n",
    "            origin ='lower',cmap ='bone',vmax=np.max(norm),aspect='auto')\n",
    "    ax[0].streamplot(X, Y, u, v,color='lavender')\n",
    "    #plt.axis('square')\n",
    "    #plt.axis([-xlims, xlims, -ylims, ylims])\n",
    "    # time window of the rollout Zn / rates to draw\n",
    "    T1 = 0\n",
    "    T2 = 300\n",
    "    ax[0].set_box_aspect(1)\n",
    "    ax[0].spines[['right','top']].set_visible(False)\n",
    "\n",
    "    ax[0].set_xlim(-xlim1,xlim2)\n",
    "    ax[0].set_ylim(-ylims,ylims)\n",
    "    ax[0].set_xticks([-xlim1,0,xlim2])\n",
    "    ax[0].set_yticks([-ylims,0,ylims])\n",
    "    ax[0].set_xlabel(r\"$z_1$\")\n",
    "    ax[0].set_ylabel(r'$z_2$')\n",
    "    ax[0].plot(Zn[0,T1:T2],Zn[1,T1:T2],color='purple',lw=3)\n",
    "\n",
    "\n",
    "\n",
    "    # first 10 observed traces, offset by 2 each\n",
    "    for i in range(10):\n",
    "        ax[1].plot(rates[i,T1:T2]+i*2, color='black',alpha=.7)\n",
    "    ax[1].set_box_aspect(1)\n",
    "    ax[1].set_xlim(0,T2-T1)\n",
    "    ax[1].spines[['right','top']].set_visible(False)\n",
    "    ax[1].set_yticks(np.arange(0,20,2))\n",
    "    ax[1].set_yticklabels([])\n",
    "    ax[1].set_xlabel(\"timesteps\")\n",
    "    ax[1].set_ylabel(\"neurons\")\n",
    "    ax[0].set_title(\"latents\")\n",
    "    ax[1].set_title(\"observed\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "284c09e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observation dimensionality (reassigned from data.shape in the analysis cells above).\n",
    "dim_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5f7fed9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the generated observations from the free rollout: (dim_x, time steps).\n",
    "rates.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d614bae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-simulate T2 - T1 steps from z0 (set in the free-rollout cell) and plot the latent\n",
    "# trajectory plus the first n_obs generated observation traces, offset by 2 each.\n",
    "T1, T2 = 0, 300\n",
    "n_steps = T2 - T1\n",
    "latent_code = vae.prior.get_latent_time_series(time_steps=n_steps, z0=z0)\n",
    "obs = vae.prior.get_observation(latent_code, noise_scale=1)\n",
    "rates = obs.permute(0, 2, 1, 3).reshape(n_steps, dim_x).detach().cpu().numpy()\n",
    "latent_code = latent_code.detach().cpu().numpy()\n",
    "\n",
    "fig, ax = plt.subplots(1, figsize=(7, 5))\n",
    "ax.plot(latent_code[0, 0, T1:T2], latent_code[0, 1, T1:T2])\n",
    "ax.spines[['right', 'top']].set_visible(False)\n",
    "\n",
    "n_obs = 5\n",
    "fig, ax = plt.subplots(1, figsize=(7, 5))\n",
    "for k in range(n_obs):\n",
    "    ax.plot(rates[T1:T2, k] + 2 * k)\n",
    "ax.spines[['right', 'top']].set_visible(False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87b48416",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c2e9381",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e8d0226",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4702d9d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7df8f27e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9fd2ff3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9d2e1fb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae9389d9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86c564f2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
