{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## set up imports etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"7\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "sys.path.append('../')\n",
    "\n",
    "from rnn.vae import VAE\n",
    "from rnn.train import train_VAE\n",
    "from rnn.saving import save_model, load_model\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import h5py\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "from pathlib import Path\n",
    "\n",
    "mpl.rc_file(\"matplotlibrc\")\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## load model and dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_ROOT = Path(\"../../\").absolute() / \"data\" / \"processed\"\n",
    "RUN_ROOT = Path(\"../../\").absolute() / \"runs\"\n",
    "CHKPT_DIR = Path(\"../../\").absolute() / \"checkpoints\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chkpt = \"maze_nocue_nc_5d/\"\n",
    "chkpt_path = CHKPT_DIR / chkpt\n",
    "\n",
    "model_save_name = sorted(chkpt_path.glob(\"*.pkl\"))[0].stem\n",
    "\n",
    "# load model\n",
    "suffixes = [\n",
    "    \"_state_dict_enc\", \n",
    "    \"_state_dict_prior\", \n",
    "    \"_task_params\", \n",
    "    \"_training_params\", \n",
    "    \"_vae_params\", \n",
    "]\n",
    "for suffix in suffixes:\n",
    "    model_save_name = model_save_name.replace(suffix, \"\")\n",
    "model_save_name = str(chkpt_path / model_save_name)\n",
    "vae, vae_params, task_params, training_params = load_model(model_save_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_dataset(\n",
    "    name: str,\n",
    "    phase: str = \"val\",\n",
    "    bin_size: int = 5,\n",
    "    t_forward: int = 0,\n",
    "    input_field: str = None,\n",
    "    u: int = 0,\n",
    "):\n",
    "    data_path = DATA_ROOT / name / phase\n",
    "    train_inputs = eval_inputs = None\n",
    "    normalization = None\n",
    "    with h5py.File(data_path / f\"train_input_{bin_size}ms.h5\", \"r\") as h5f:\n",
    "        if t_forward > 0:\n",
    "            train_data = np.concatenate([\n",
    "                np.concatenate([\n",
    "                    h5f['train_spikes_heldin'][()],\n",
    "                    h5f['train_spikes_heldin_forward'][()],\n",
    "                ], axis=1),\n",
    "                np.concatenate([\n",
    "                    h5f['train_spikes_heldout'][()],\n",
    "                    h5f['train_spikes_heldout_forward'][()],\n",
    "                ], axis=1),\n",
    "            ], axis=2)\n",
    "        else:\n",
    "            train_data = np.concatenate([h5f['train_spikes_heldin'][()], h5f['train_spikes_heldout'][()]], axis=2)\n",
    "        if input_field is not None:\n",
    "            train_inputs = h5f[f\"train_{input_field}\"][()]\n",
    "            if len(train_inputs.shape) == 1:\n",
    "                train_inputs = train_inputs[:, None]\n",
    "            if len(train_inputs.shape) == 2:\n",
    "                train_inputs = np.tile(train_inputs[:, None, :], reps=(1, train_data.shape[1], 1))\n",
    "            if train_inputs.shape[-1] > u:\n",
    "                train_inputs = train_inputs[:, :, :u]\n",
    "            normalization = np.max(np.abs(train_inputs), axis=(0,1), keepdims=True)\n",
    "            train_inputs = train_inputs / normalization\n",
    "    with h5py.File(data_path / f\"eval_input_{bin_size}ms.h5\", \"r\") as h5f:\n",
    "        assert \"eval_spikes_heldout\" in h5f.keys()\n",
    "        eval_data = np.concatenate([h5f['eval_spikes_heldin'][()], h5f['eval_spikes_heldout'][()]], axis=2)\n",
    "        if input_field is not None:\n",
    "            eval_inputs = h5f[f\"eval_{input_field}\"][()]\n",
    "            if len(eval_inputs.shape) == 1:\n",
    "                eval_inputs = eval_inputs[:, None]\n",
    "            if len(eval_inputs.shape) == 2:\n",
    "                eval_inputs = np.tile(eval_inputs[:, None, :], reps=(1, eval_data.shape[1], 1))\n",
    "            if eval_inputs.shape[-1] > u:\n",
    "                eval_inputs = eval_inputs[:, :, :u]\n",
    "            assert normalization is not None\n",
    "            eval_inputs = eval_inputs / normalization\n",
    "    return (\n",
    "        torch.tensor(train_data, dtype=torch.float), \n",
    "        torch.tensor(eval_data, dtype=torch.float),\n",
    "        torch.tensor(train_inputs, dtype=torch.float) if train_inputs is not None else train_inputs,\n",
    "        torch.tensor(eval_inputs, dtype=torch.float) if eval_inputs is not None else eval_inputs,        \n",
    "    )\n",
    "\n",
    "\n",
    "class NLBDataset(Dataset):\n",
    "    def __init__(self, data, params, inputs=None):\n",
    "        self.data = data\n",
    "        self.data_eval = data # Not used in this setup\n",
    "        self.trial_dur = data.shape[0]\n",
    "        self.task_params = params\n",
    "        self.stim = (\n",
    "            inputs.float().to(self.data.device)\n",
    "            if inputs is not None else \n",
    "            torch.zeros(data.shape[0], data.shape[1], 0, device=self.data.device, dtype=torch.float)\n",
    "        )\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        return self.data[idx].T, self.stim[idx].T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data, eval_data, train_inputs, eval_inputs = load_dataset(\n",
    "    name=\"mc_maze_input\",\n",
    "    phase=\"cue\",\n",
    "    bin_size=20,\n",
    "    t_forward=0,\n",
    "    input_field=\"input\",\n",
    "    u=2,\n",
    ")\n",
    "train_dataset = NLBDataset(train_data, dict(name=\"mc_maze_input\"), inputs=train_inputs)\n",
    "eval_dataset = NLBDataset(eval_data, dict(name=\"mc_maze_input\"), inputs=eval_inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## generate latents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dim_z = vae.vae_params[\"dim_z\"]\n",
    "n_inp = vae.vae_params[\"dim_u\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vae.to_device(\"cpu\")\n",
    "\n",
    "noise =0\n",
    "dur = 35\n",
    "t_on = 0\n",
    "t_of = 35\n",
    "z0=torch.randn(1,dim_z,1).cuda()*.1\n",
    "u = torch.zeros(1,n_inp,dur,device='cpu')\n",
    "\n",
    "u[0,1,t_on:t_of]=-1\n",
    "u[0,0,t_on:t_of]=1\n",
    "\n",
    "with torch.no_grad():\n",
    "    for _ in range(1):\n",
    "        z = vae.prior.get_latent_time_series(dur,z0=z0,u=u,noise_scale=noise).cpu().numpy()\n",
    "        plt.plot(z[0,:,:,0].T,color='black',alpha=.6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "def orthogonalise_network(vae):\n",
    "    with torch.no_grad():\n",
    "        m_or = vae.prior.transition.m #20,2\n",
    "        n_or = vae.prior.transition.n\n",
    "        J = m_or@n_or\n",
    "        u,s,v = torch.linalg.svd(J)\n",
    "        projection_matrix = u[:,:vae.dim_z].T@m_or\n",
    "        proj_chol = projection_matrix@torch.diag(vae.prior.std_embed_z(vae.prior.R_z))\n",
    "        m_new = u[:,:vae.dim_z]\n",
    "        n_new = (v[:vae.dim_z].T * s[:vae.dim_z]).T\n",
    "        vae.prior.transition.m.copy_(m_new)\n",
    "        vae.prior.transition.n.copy_(n_new)\n",
    "        vae.prior.chol_cov_embed = lambda x: torch.tril(x)\n",
    "        vae.prior.R_z=torch.nn.Parameter(torch.linalg.cholesky(proj_chol@proj_chol.T))\n",
    "        vae.prior.params['scalar_noise_z']=\"Cov\"\n",
    "    return vae\n",
    "\n",
    "vae_orth = copy.deepcopy(vae)\n",
    "vae_orth = orthogonalise_network(vae_orth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    m_or = vae.prior.transition.m #20,2\n",
    "    n_or = vae.prior.transition.n\n",
    "    J = m_or@n_or\n",
    "    u,s,v = torch.linalg.svd(J)\n",
    "    projection_matrix = u[:,:vae.dim_z].T@m_or\n",
    "projection_matrix = projection_matrix.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "noise =0\n",
    "dur = 35\n",
    "t_on = 0\n",
    "t_of = 35\n",
    "z0=torch.randn(1,dim_z,1).cuda()*.1\n",
    "u = torch.zeros(1,n_inp,dur,device='cpu')\n",
    "\n",
    "u[0,1,t_on:t_of]=-1\n",
    "u[0,0,t_on:t_of]=1\n",
    "\n",
    "with torch.no_grad():\n",
    "    for _ in range(1):\n",
    "        z_orth = vae_orth.prior.get_latent_time_series(dur,z0=z0,u=u,noise_scale=noise).cpu().numpy()\n",
    "        plt.plot(z_orth[0,:,:,0].T,color='black',alpha=.6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## infer NLB latents from data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vae.to_device('cuda')\n",
    "marginal_smoothing = False\n",
    "k = training_params['k']\n",
    "\n",
    "i = 0\n",
    "x = train_dataset.data[i].T.unsqueeze(0)\n",
    "u = train_dataset.stim[i].T.unsqueeze(0)\n",
    "t_held_in = x.shape[2]\n",
    "with torch.no_grad():\n",
    "    Qzs_filt, Qzs_sm, Xs_filt, Xs_sm= vae.predict_NLB(x.cuda(), u=u.cuda(),k=k,t_held_in = t_held_in, t_forward=0,marginal_smoothing=marginal_smoothing)\n",
    "plt.plot(Qzs_filt[0,:,:,0].cpu().T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_repeats = 10\n",
    "batch_size = 64\n",
    "k = 100\n",
    "device = \"cuda:0\"\n",
    "\n",
    "x_t = train_dataset.data.permute(0,2,1).to(device)\n",
    "u_t = train_dataset.stim.permute(0,2,1).to(device)\n",
    "t_held_in = x_t.shape[2]\n",
    "\n",
    "Qzs_filt_repeats = []\n",
    "Qzs_sm_repeats = []\n",
    "Xs_filt_repeats = []\n",
    "Xs_sm_repeats = []\n",
    "for i in range(n_repeats):\n",
    "    with torch.no_grad():\n",
    "        Qzs_filt, Qzs_sm, Xs_filt, Xs_sm = zip(*[\n",
    "            vae.predict_NLB(x_t_chunk, u=u_t_chunk, k=k, t_held_in = t_held_in, t_forward=0)\n",
    "            for x_t_chunk, u_t_chunk in zip(\n",
    "                torch.chunk(x_t, chunks=len(x_t) // batch_size, dim=0),\n",
    "                torch.chunk(u_t, chunks=len(u_t) // batch_size, dim=0),\n",
    "            )\n",
    "        ])\n",
    "        Qzs_filt_repeats.append(torch.cat(Qzs_filt, dim=0).cpu().numpy().mean(axis=-1).transpose(0, 2, 1))\n",
    "        Qzs_sm_repeats.append(torch.cat(Qzs_sm, dim=0).cpu().numpy().mean(axis=-1).transpose(0, 2, 1))\n",
    "        Xs_filt_repeats.append(torch.cat(Xs_filt, dim=0).cpu().numpy().mean(axis=-1).transpose(0, 2, 1))\n",
    "        Xs_sm_repeats.append(torch.cat(Xs_sm, dim=0).cpu().numpy().mean(axis=-1).transpose(0, 2, 1))\n",
    "\n",
    "Qzs_filt = np.mean(Qzs_filt_repeats, axis=0)\n",
    "Qzs_sm = np.mean(Qzs_sm_repeats, axis=0)\n",
    "Xs_filt = np.mean(Xs_filt_repeats, axis=0)\n",
    "Xs_sm = np.mean(Xs_sm_repeats, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_pos = train_dataset.stim[:, 0, :].detach().cpu().numpy()\n",
    "angles = np.arctan2(target_pos[:, 1], target_pos[:, 0])\n",
    "angles = angles / (2 * np.pi) + 0.5\n",
    "angles = (np.round(angles * 8) % 8) / 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(target_pos[:, 0], target_pos[:, 1], c=angles, cmap=plt.cm.hsv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Qz_mean = np.empty((len(np.unique(angles)), *Qzs_filt.shape[1:]))\n",
    "for i, angle in enumerate(np.sort(np.unique(angles))):\n",
    "    mask = (angles == angle)\n",
    "    Qz_mean[i] = Qzs_filt[mask].mean(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=plt.figaspect(0.5))\n",
    "plt_start = 0\n",
    "plt_end=35\n",
    "fig,ax = plt.subplots(dim_z,dim_z)\n",
    "for angle_i in range(len(np.sort(np.unique(angles)))):\n",
    "    angle = np.sort(np.unique(angles))[angle_i]\n",
    "    for i in range(dim_z):\n",
    "        for j in range(i,dim_z):\n",
    "            if i == j:\n",
    "                ax[i,j].plot(Qz_mean[angle_i,plt_start:plt_end,i],alpha = .5, color=plt.cm.hsv(angle))\n",
    "            else:\n",
    "                ax[i,j].plot(\n",
    "                    Qz_mean[angle_i,plt_start:plt_end,i],\n",
    "                    Qz_mean[angle_i,plt_start:plt_end,j],\n",
    "                    alpha = .5, color=plt.cm.hsv(angle)\n",
    "                )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Qz_mean_orth = Qz_mean @ projection_matrix\n",
    "\n",
    "fig = plt.figure(figsize=plt.figaspect(0.5))\n",
    "plt_start = 10\n",
    "plt_end=35\n",
    "fig,ax = plt.subplots(dim_z,dim_z)\n",
    "for angle_i in range(len(np.sort(np.unique(angles)))):\n",
    "    angle = np.sort(np.unique(angles))[angle_i]\n",
    "    for i in range(dim_z):\n",
    "        for j in range(i,dim_z):\n",
    "            if i == j:\n",
    "                ax[i,j].plot(Qz_mean_orth[angle_i,plt_start:plt_end,i],alpha = .5, color=plt.cm.hsv(angle))\n",
    "            else:\n",
    "                ax[i,j].plot(\n",
    "                    Qz_mean_orth[angle_i,plt_start:plt_end,i],\n",
    "                    Qz_mean_orth[angle_i,plt_start:plt_end,j],\n",
    "                    alpha = .5, color=plt.cm.hsv(angle)\n",
    "                )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Qz_mean_orth = Qz_mean @ projection_matrix\n",
    "\n",
    "fig = plt.figure(figsize=plt.figaspect(0.5))\n",
    "plt_start = 2\n",
    "plt_end= 5\n",
    "fig,ax = plt.subplots(dim_z,dim_z)\n",
    "for angle_i in range(len(np.sort(np.unique(angles)))):\n",
    "    angle = np.sort(np.unique(angles))[angle_i]\n",
    "    for i in range(dim_z):\n",
    "        for j in range(i,dim_z):\n",
    "            if i == j:\n",
    "                ax[i,j].plot(Qz_mean_orth[angle_i,plt_start:plt_end,i],alpha = .5, color=plt.cm.hsv(angle))\n",
    "            else:\n",
    "                ax[i,j].plot(\n",
    "                    Qz_mean_orth[angle_i,plt_start:plt_end,i],\n",
    "                    Qz_mean_orth[angle_i,plt_start:plt_end,j],\n",
    "                    alpha = .5, color=plt.cm.hsv(angle)\n",
    "                )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Qz_mean_orth = Qz_mean @ projection_matrix\n",
    "Qzs_orth_all = Qzs_filt @ projection_matrix\n",
    "plt_start = 6\n",
    "plt_end = 10\n",
    "plt_dim1 = 0\n",
    "plt_dim2 = 1\n",
    "sign_dim1 = 1\n",
    "sign_dim2 = -1\n",
    "fig, axs = plt.subplots(1, 1, figsize=(1.0, 1.0))\n",
    "for angle_i in range(len(np.sort(np.unique(angles)))):\n",
    "    angle = np.sort(np.unique(angles))[angle_i]\n",
    "    \n",
    "    mask = (angles == angle)\n",
    "    Qzs_orth_all_cond = Qzs_orth_all[mask]\n",
    "    for trial in range(min(Qzs_orth_all_cond.shape[0], 5)):\n",
    "        axs.plot(\n",
    "            # np.concatenate([initial_state_proj[plt_dim1:plt_dim1+1], z_orth_mean[targ_i,plt_start:plt_end,plt_dim1]]),\n",
    "            # np.concatenate([initial_state_proj[plt_dim2:plt_dim2+1], z_orth_mean[targ_i,plt_start:plt_end,plt_dim2]]), \n",
    "            Qzs_orth_all_cond[trial,plt_start:plt_end,plt_dim1] * sign_dim1,\n",
    "            Qzs_orth_all_cond[trial,plt_start:plt_end,plt_dim2] * sign_dim2,\n",
    "            color=plt.cm.hsv(angle), alpha=0.3, linewidth=0.6)\n",
    "            \n",
    "    axs.plot(\n",
    "        Qz_mean_orth[angle_i,plt_start:plt_end,plt_dim1] * sign_dim1,\n",
    "        Qz_mean_orth[angle_i,plt_start:plt_end,plt_dim2] * sign_dim2,\n",
    "        color=plt.cm.hsv(angle), alpha=0.6)\n",
    "axs.scatter(Qz_mean_orth[:,plt_end-1,plt_dim1] * sign_dim1,Qz_mean_orth[:,plt_end-1,plt_dim2] * sign_dim2, \n",
    "            s=20, color=[plt.cm.hsv(a) for a in np.sort(np.unique(angles))], edgecolor=\"black\", linewidth=0.8, zorder=2, alpha=0.6)\n",
    "axs.set_xticks([])\n",
    "axs.set_yticks([])\n",
    "axs.set_xlabel(f\"$z_{plt_dim1+1}$\")\n",
    "axs.set_ylabel(f\"$z_{plt_dim2+1}$\")\n",
    "\n",
    "# xlim = axs.get_xlim()\n",
    "# xrange = xlim[1] - xlim[0]\n",
    "# axs.set_xlim(xlim[0] + 0.04 * xrange, xlim[1] - 0.16 * xrange)\n",
    "\n",
    "ylim = axs.get_ylim()\n",
    "yrange = ylim[1] - ylim[0]\n",
    "axs.set_ylim(ylim[0] - 0.05 * yrange, ylim[1])\n",
    "\n",
    "prep_xlim, prep_ylim = axs.get_xlim(), axs.get_ylim()\n",
    "\n",
    "axs.set_title(\"pre-movement period\")\n",
    "plt.savefig(\"../../plots/prep_inf_latents.svg\")\n",
    "plt.savefig(\"../../plots/prep_inf_latents.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 1, figsize=(1.2, 1.2), subplot_kw={'projection': '3d'})\n",
    "\n",
    "Qz_mean_orth = Qz_mean @ projection_matrix\n",
    "Qzs_orth_all = Qzs_filt @ projection_matrix\n",
    "plt_start = 10\n",
    "plt_end = 35\n",
    "plt_dim1 = 4\n",
    "plt_dim2 = 3\n",
    "plt_dim3 = 2\n",
    "sign_dim1 = 1\n",
    "sign_dim2 = 1\n",
    "sign_dim3 = 1\n",
    "\n",
    "prop_cycle = [plt.cm.hsv(ang) for ang in np.sort(np.unique(angles))]\n",
    "\n",
    "for angle_i in range(len(np.sort(np.unique(angles)))):\n",
    "    angle = np.sort(np.unique(angles))[angle_i]\n",
    "        \n",
    "    mask = (angles == angle)\n",
    "    Qzs_orth_all_cond = Qzs_orth_all[mask]\n",
    "    for trial in range(min(Qzs_orth_all_cond.shape[0], 5)):\n",
    "        axs.plot(\n",
    "            Qzs_orth_all_cond[trial,plt_start:plt_end,plt_dim1]*sign_dim1,\n",
    "            Qzs_orth_all_cond[trial,plt_start:plt_end,plt_dim2]*sign_dim2,\n",
    "            Qzs_orth_all_cond[trial,plt_start:plt_end,plt_dim3]*sign_dim3,\n",
    "            color=prop_cycle[angle_i], alpha=0.3, linewidth=0.6)\n",
    "            \n",
    "    axs.plot(\n",
    "        Qz_mean_orth[angle_i,plt_start:plt_end,plt_dim1]*sign_dim1,\n",
    "        Qz_mean_orth[angle_i,plt_start:plt_end,plt_dim2]*sign_dim2,\n",
    "        Qz_mean_orth[angle_i,plt_start:plt_end,plt_dim3]*sign_dim3,\n",
    "        color=prop_cycle[angle_i], alpha=0.6)\n",
    "        \n",
    "axs.scatter(Qz_mean_orth[:,plt_start,plt_dim1]*sign_dim1,Qz_mean_orth[:,plt_start,plt_dim2]*sign_dim2, Qz_mean_orth[:,plt_start,plt_dim3]*sign_dim3,\n",
    "            s=15, color=prop_cycle, edgecolor=\"black\", linewidth=0.8, zorder=2, alpha=0.6)\n",
    "axs.scatter(Qz_mean_orth[:,plt_end-1,plt_dim1]*sign_dim1,Qz_mean_orth[:,plt_end-1,plt_dim2]*sign_dim2, Qz_mean_orth[:,plt_end-1,plt_dim3]*sign_dim3,\n",
    "            s=15, marker=\"^\", color=prop_cycle, edgecolor=\"black\", linewidth=0.8, zorder=2, alpha=0.6)\n",
    "axs.set_xticks([])\n",
    "axs.set_yticks([])\n",
    "axs.set_zticks([])\n",
    "axs.set_xlabel(f\"$z_{plt_dim1+1}$\")\n",
    "axs.set_ylabel(f\"$z_{plt_dim2+1}$\")\n",
    "axs.set_zlabel(f\"$z_{plt_dim3+1}$\")\n",
    "axs.xaxis.labelpad=-15\n",
    "axs.yaxis.labelpad=-15\n",
    "axs.zaxis.labelpad=-16\n",
    "\n",
    "# zlim = axs.get_zlim()\n",
    "# zrange = zlim[1] - zlim[0]\n",
    "# axs.set_zlim(zlim[0] - 0.1 * zrange, zlim[1] - 0.2 * zrange)\n",
    "\n",
    "mvt_xlim, mvt_ylim, mvt_zlim = axs.get_xlim(), axs.get_ylim(), axs.get_zlim()\n",
    "\n",
    "axs.set_title(\"movement period\", pad=-10)\n",
    "plt.savefig(\"../../plots/movement_inf_latents.svg\")\n",
    "plt.savefig(\"../../plots/movement_inf_latents.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# generate trajectories wth corresponding angles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def project_to_square(theta):\n",
    "    if theta > 180:\n",
    "        theta -= 360\n",
    "    elif theta < -180:\n",
    "        theta += 360\n",
    "    if (theta >= 45 and theta < 135):\n",
    "        y = 1.0\n",
    "        x = y * np.cos(theta / 180 * np.pi) / np.sin(theta / 180 * np.pi)\n",
    "    elif (theta >= 135 or theta < -135):\n",
    "        x = -1.0\n",
    "        y = x * np.sin(theta / 180 * np.pi) / np.cos(theta / 180 * np.pi)\n",
    "    elif (theta >= -135 and theta < -45):\n",
    "        y = -1.0\n",
    "        x = y * np.cos(theta / 180 * np.pi) / np.sin(theta / 180 * np.pi)\n",
    "    else:\n",
    "        x = 1.0\n",
    "        y = x * np.sin(theta / 180 * np.pi) / np.cos(theta / 180 * np.pi)\n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot 3 of the latents\n",
    "noise_scale = 1\n",
    "# targets = [\n",
    "#     [-1,0],\n",
    "#     [-1,-1],\n",
    "#     [0,-1],\n",
    "#     [1,-1],\n",
    "#     [1,0],\n",
    "#     [1,1],\n",
    "#     [0,1],\n",
    "#     [-1,1],\n",
    "# ]\n",
    "gen_angles = np.arange(-180, 180, 22.5)\n",
    "targets = [project_to_square(ang) for ang in gen_angles]\n",
    "dur = 35\n",
    "t_on=0\n",
    "t_of= 35\n",
    "R_z=0.05\n",
    "prop_cycle =[plt.cm.hsv(i) for i in np.arange(0, 1, 1/len(targets))]\n",
    "\n",
    "n_repeats = 50\n",
    "z_all =np.zeros((n_repeats,len(targets),dur,dim_z))\n",
    "for ri in range(n_repeats):\n",
    "    input = np.zeros((len(targets), dur,n_inp))\n",
    "    z=np.zeros((len(targets), dur,dim_z))\n",
    "    for i, target in enumerate(targets):\n",
    "        input[i, t_on:t_of,:]=np.array(target)[None, :]\n",
    "    input = torch.tensor(input, dtype=torch.float).permute(0, 2, 1).cuda()\n",
    "    z0 = Qzs_filt[np.random.permutation(Qzs_filt.shape[0])[:input.shape[0]], 0, :][:, :, None, None]\n",
    "    z0 = torch.tensor(z0, dtype=torch.float)\n",
    "    z = vae.prior.get_latent_time_series(dur,u=input,z0=z0,noise_scale=noise_scale).cpu().numpy()\n",
    "    z_all[ri]=z.transpose(0, 2, 1, 3).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_all.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_mean = np.mean(z_all,axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=plt.figaspect(0.5))\n",
    "plt_start = 8\n",
    "plt_end=10\n",
    "fig,ax = plt.subplots(dim_z,dim_z)\n",
    "for i in range(dim_z):\n",
    "    for j in range(i,dim_z):\n",
    "        ax[i,j].set_prop_cycle('color',prop_cycle)\n",
    "for targ_i in range(len(targets)):\n",
    "    for i in range(dim_z):\n",
    "        for j in range(i,dim_z):\n",
    "            if i == j:\n",
    "                ax[i,j].plot(z_mean[targ_i,plt_start:plt_end,i],alpha = .5)\n",
    "            else:\n",
    "                ax[i,j].plot(z_mean[targ_i,plt_start:plt_end,i],z_mean[targ_i,plt_start:plt_end,j],alpha = .5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_orth_mean = z_mean @ projection_matrix\n",
    "\n",
    "fig = plt.figure(figsize=plt.figaspect(0.5))\n",
    "plt_start = 8\n",
    "plt_end=10\n",
    "fig,ax = plt.subplots(dim_z,dim_z)\n",
    "for i in range(dim_z):\n",
    "    for j in range(i,dim_z):\n",
    "        ax[i,j].set_prop_cycle('color',prop_cycle)\n",
    "for targ_i in range(len(targets)):\n",
    "    for i in range(dim_z):\n",
    "        for j in range(i,dim_z):\n",
    "            if i == j:\n",
    "                ax[i,j].plot(z_orth_mean[targ_i,plt_start:plt_end,i],alpha = .5)\n",
    "            else:\n",
    "                ax[i,j].plot(z_orth_mean[targ_i,plt_start:plt_end,i],z_orth_mean[targ_i,plt_start:plt_end,j],alpha = .5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_orth_all = z_all @ projection_matrix\n",
    "plt_start =6\n",
    "plt_end = 9\n",
    "plt_dim1 = 0\n",
    "plt_dim2 = 1\n",
    "sign_dim1 = 1\n",
    "sign_dim2 = -1\n",
    "fig, axs = plt.subplots(1, 1, figsize=(1.0, 1.0))\n",
    "prop_cycle =[plt.cm.hsv(i) for i in np.arange(0, 1, 1/len(targets))]\n",
    "for targ_i in range(len(targets)):\n",
    "    \n",
    "    for trial in range(min(z_orth_all.shape[0], 5)):\n",
    "        axs.plot(\n",
    "            # np.concatenate([initial_state_proj[plt_dim1:plt_dim1+1], z_orth_mean[targ_i,plt_start:plt_end,plt_dim1]]),\n",
    "            # np.concatenate([initial_state_proj[plt_dim2:plt_dim2+1], z_orth_mean[targ_i,plt_start:plt_end,plt_dim2]]), \n",
    "            z_orth_all[trial,targ_i,plt_start:plt_end,plt_dim1]*sign_dim1,\n",
    "            z_orth_all[trial,targ_i,plt_start:plt_end,plt_dim2]*sign_dim2,\n",
    "            color=prop_cycle[targ_i], alpha=0.3, linewidth=0.6)\n",
    "        \n",
    "    axs.plot(\n",
    "        # np.concatenate([initial_state_proj[plt_dim1:plt_dim1+1], z_orth_mean[targ_i,plt_start:plt_end,plt_dim1]]),\n",
    "        # np.concatenate([initial_state_proj[plt_dim2:plt_dim2+1], z_orth_mean[targ_i,plt_start:plt_end,plt_dim2]]), \n",
    "        z_orth_mean[targ_i,plt_start:plt_end,plt_dim1]*sign_dim1,\n",
    "        z_orth_mean[targ_i,plt_start:plt_end,plt_dim2]*sign_dim2,\n",
    "        color=prop_cycle[targ_i], alpha=0.6)\n",
    "axs.scatter(z_orth_mean[:,plt_end-1,plt_dim1]*sign_dim1,z_orth_mean[:,plt_end-1,plt_dim2]*sign_dim2, \n",
    "            s=20, color=prop_cycle, edgecolor=\"black\", linewidth=0.8, zorder=2, alpha=0.6)\n",
    "axs.set_xticks([])\n",
    "axs.set_yticks([])\n",
    "axs.set_xlabel(f\"$z_{plt_dim1+1}$\")\n",
    "axs.set_ylabel(f\"$z_{plt_dim2+1}$\")\n",
    "\n",
    "# xlim = axs.get_xlim()\n",
    "# xrange = xlim[1] - xlim[0]\n",
    "# axs.set_xlim(xlim[0] - 0.04 * xrange, xlim[1] + 0.04 * xrange)\n",
    "\n",
    "# ylim = axs.get_ylim()\n",
    "# yrange = ylim[1] - ylim[0]\n",
    "# axs.set_ylim(ylim[0] - 0.04 * yrange, ylim[1] + 0.04 * yrange)\n",
    "\n",
    "axs.set_xlim(*prep_xlim)\n",
    "axs.set_ylim(*prep_ylim)\n",
    "\n",
    "axs.set_title(\"pre-movement period\")\n",
    "plt.savefig(\"../../plots/prep_gen_latents.svg\")\n",
    "plt.savefig(\"../../plots/prep_gen_latents.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 1, figsize=(1.2, 1.2), subplot_kw={'projection': '3d'})\n",
    "\n",
    "z_orth_all = z_all @ projection_matrix\n",
    "plt_start = 10\n",
    "plt_end = 35\n",
    "plt_dim1 = 4\n",
    "plt_dim2 = 3\n",
    "plt_dim3 = 2\n",
    "sign_dim1 = 1\n",
    "sign_dim2 = 1\n",
    "sign_dim3 = 1\n",
    "\n",
    "strd = 2\n",
    "prop_cycle =[plt.cm.hsv(i) for i in np.arange(0, 1, 1/len(targets))]\n",
    "for targ_i in range(len(targets)):\n",
    "        \n",
    "    for trial in range(min(z_orth_all.shape[0], 5)):\n",
    "        axs.plot(\n",
    "            z_orth_all[trial,targ_i,plt_start:plt_end,plt_dim1]*sign_dim1,\n",
    "            z_orth_all[trial,targ_i,plt_start:plt_end,plt_dim2]*sign_dim2,\n",
    "            z_orth_all[trial,targ_i,plt_start:plt_end,plt_dim3]*sign_dim3,\n",
    "            color=prop_cycle[targ_i], alpha=0.3, linewidth=0.6)\n",
    "        \n",
    "    if targ_i % strd == 0:\n",
    "        axs.plot(\n",
    "            z_orth_mean[targ_i,plt_start:plt_end,plt_dim1]*sign_dim1,\n",
    "            z_orth_mean[targ_i,plt_start:plt_end,plt_dim2]*sign_dim2,\n",
    "            z_orth_mean[targ_i,plt_start:plt_end,plt_dim3]*sign_dim3,\n",
    "            color=prop_cycle[targ_i], alpha=0.6)\n",
    "        \n",
    "    # else:\n",
    "    #     axs.plot(\n",
    "    #         z_orth_mean[targ_i,plt_start:plt_end,plt_dim1]*sign_dim1,\n",
    "    #         z_orth_mean[targ_i,plt_start:plt_end,plt_dim2]*sign_dim2,\n",
    "    #         z_orth_mean[targ_i,plt_start:plt_end,plt_dim3]*sign_dim3,\n",
    "    #         color=prop_cycle[targ_i], alpha=0.6)\n",
    "        \n",
    "strd = 2\n",
    "axs.scatter(z_orth_mean[::strd,plt_start,plt_dim1]*sign_dim1,z_orth_mean[::strd,plt_start,plt_dim2]*sign_dim2, z_orth_mean[::strd,plt_start,plt_dim3]*sign_dim3,\n",
    "            s=15, color=prop_cycle[::strd], edgecolor=\"black\", linewidth=0.8, zorder=2, alpha=0.8)\n",
    "axs.scatter(z_orth_mean[::strd,plt_end-1,plt_dim1]*sign_dim1,z_orth_mean[::strd,plt_end-1,plt_dim2]*sign_dim2, z_orth_mean[::strd,plt_end-1,plt_dim3]*sign_dim3,\n",
    "            s=15, marker=\"^\", color=prop_cycle[::strd], edgecolor=\"black\", linewidth=0.8, zorder=2, alpha=0.8)\n",
    "axs.set_xticks([])\n",
    "axs.set_yticks([])\n",
    "axs.set_zticks([])\n",
    "axs.set_xlabel(f\"$z_{plt_dim1+1}$\")\n",
    "axs.set_ylabel(f\"$z_{plt_dim2+1}$\")\n",
    "axs.set_zlabel(f\"$z_{plt_dim3+1}$\")\n",
    "axs.xaxis.labelpad=-15\n",
    "axs.yaxis.labelpad=-15\n",
    "axs.zaxis.labelpad=-16\n",
    "\n",
    "axs.set_xlim(*mvt_xlim)\n",
    "axs.set_ylim(*mvt_ylim)\n",
    "axs.set_zlim(*mvt_zlim)\n",
    "\n",
    "axs.set_title(\"movement period\", pad=-10)\n",
    "plt.savefig(\"../../plots/movement_gen_latents.svg\")\n",
    "plt.savefig(\"../../plots/movement_gen_latents.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## decoded reaches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = DATA_ROOT / \"mc_maze_input\" / \"pos\" / \"eval_target_20ms.h5\"\n",
    "with h5py.File(data_path, \"r\") as h5f:\n",
    "    train_behavior = h5f[\"mc_maze_20\"][\"train_behavior\"][()]\n",
    "    eval_behavior = h5f[\"mc_maze_20\"][\"eval_behavior\"][()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import Ridge\n",
    "\n",
    "\n",
    "def flatten2d(arr):\n",
    "    \"\"\"Collapse every leading axis of `arr` into one, keeping the last axis.\"\"\"\n",
    "    return arr.reshape(-1, arr.shape[-1])\n",
    "\n",
    "\n",
    "# Fit on the first 400 entries along axis 0 and report the score on the remaining ones.\n",
    "n_fit_trials = 400\n",
    "rate_decoder = Ridge(alpha=1e-6)\n",
    "rate_decoder.fit(\n",
    "    flatten2d(Xs_filt[:n_fit_trials]),\n",
    "    flatten2d(train_behavior[:n_fit_trials]),\n",
    ")\n",
    "\n",
    "heldout_score = rate_decoder.score(\n",
    "    flatten2d(Xs_filt[n_fit_trials:]),\n",
    "    flatten2d(train_behavior[n_fit_trials:]),\n",
    ")\n",
    "print(heldout_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ridge readout of train_behavior from Qzs_filt; the score uses the same data as the fit.\n",
    "latent_features = flatten2d(Qzs_filt)\n",
    "behavior_targets = flatten2d(train_behavior)\n",
    "\n",
    "latent_decoder = Ridge(alpha=1e-6)\n",
    "latent_decoder.fit(latent_features, behavior_targets)\n",
    "\n",
    "print(latent_decoder.score(latent_features, behavior_targets))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode behavior for every row of the flattened input, then restore Xs_filt's leading axes.\n",
    "nlb_behavior_flat = rate_decoder.predict(flatten2d(Xs_filt))\n",
    "nlb_behavior_all = nlb_behavior_flat.reshape(*Xs_filt.shape[:-1], -1)\n",
    "# Running sum along axis 1 scaled by 1/50 (presumably integrating over 20 ms bins -- TODO confirm).\n",
    "nlb_position_all = np.cumsum(nlb_behavior_all, axis=1) / 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 1, figsize=(1.0, 1.0))\n",
    "\n",
    "# np.unique already returns sorted values; compute the condition list once.\n",
    "unique_angles = np.unique(angles)\n",
    "\n",
    "# One mean trajectory per condition (not per trial), so no uninitialised rows are left behind.\n",
    "nlb_position_mean = np.empty((len(unique_angles), *nlb_position_all.shape[1:]))\n",
    "for angle_i, angle in enumerate(unique_angles):\n",
    "    mask = (angles == angle)\n",
    "    nlb_position_mean[angle_i] = nlb_position_all[mask].mean(axis=0)\n",
    "\n",
    "for angle_i, angle in enumerate(unique_angles):\n",
    "    cond_color = plt.cm.hsv(angle)\n",
    "    # condition-averaged decoded trajectory\n",
    "    axs.plot(\n",
    "        nlb_position_mean[angle_i,:,0],\n",
    "        nlb_position_mean[angle_i,:,1],\n",
    "        color=cond_color, alpha=0.6,\n",
    "    )\n",
    "\n",
    "    # up to 10 individual trajectories of the same condition, drawn fainter\n",
    "    mask = (angles == angle)\n",
    "    nlb_position_cond = nlb_position_all[mask]\n",
    "    for trial in range(min(nlb_position_cond.shape[0], 10)):\n",
    "        axs.plot(\n",
    "            nlb_position_cond[trial,:,0],\n",
    "            nlb_position_cond[trial,:,1],\n",
    "            color=cond_color, alpha=0.3, linewidth=0.6)\n",
    "axs.set_xticks([])\n",
    "axs.set_yticks([])\n",
    "axs.set_xlabel(\"x position\")\n",
    "axs.set_ylabel(\"y position\")\n",
    "axs.set_title(\"decoded reaches\")\n",
    "\n",
    "# Trim 10% of the autoscaled x-range from each side.\n",
    "xlim = axs.get_xlim()\n",
    "xrange = xlim[1] - xlim[0]\n",
    "axs.set_xlim(xlim[0] + 0.1 * xrange, xlim[1] - 0.1 * xrange)\n",
    "\n",
    "# Keep the final limits in dec_xlim / dec_ylim so another figure can reuse them.\n",
    "dec_xlim = axs.get_xlim()\n",
    "dec_ylim = axs.get_ylim()\n",
    "\n",
    "plt.savefig(\"../../plots/decoded_inf_reaches.svg\")\n",
    "plt.savefig(\"../../plots/decoded_inf_reaches.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda:0\"\n",
    "\n",
    "# Pass each slice of z_all through the observation model without tracking gradients.\n",
    "x_all = np.empty((*z_all.shape[:-1], 182))\n",
    "with torch.no_grad():\n",
    "    for i in range(z_all.shape[0]):\n",
    "        z_i = torch.tensor(z_all[i], dtype=torch.float, device=device)\n",
    "        # add a trailing singleton axis and swap axes 1 and 2 before calling the model\n",
    "        z_i = z_i.unsqueeze(-1).permute(0, 2, 1, 3)\n",
    "        x = vae.obs_rectify(vae.prior.get_observation(z=z_i)).cpu().numpy()\n",
    "        # swap axes 1 and 2 back and drop singleton axes\n",
    "        x_all[i] = x.transpose(0, 2, 1, 3).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_all.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode behavior from x_all and restore its leading axes.\n",
    "behavior_flat = rate_decoder.predict(flatten2d(x_all))\n",
    "behavior_all = behavior_flat.reshape(*x_all.shape[:-1], -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Running sum along axis 2 scaled by 1/50 (presumably integrating over 20 ms bins -- TODO confirm).\n",
    "position_all = behavior_all.cumsum(axis=2) / 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "position_mean = position_all.mean(axis=0)\n",
    "fig, axs = plt.subplots(1, 1, figsize=(1.0, 1.0))\n",
    "strd = 2\n",
    "n_shown_trials = min(position_all.shape[0], 10)\n",
    "# Every second target: faint single trajectories first, then the axis-0 average on top.\n",
    "for targ_i in range(0, len(targets), 2):\n",
    "    targ_color = prop_cycle[targ_i]\n",
    "\n",
    "    for trial in range(n_shown_trials):\n",
    "        axs.plot(\n",
    "            position_all[trial,targ_i,:,0],\n",
    "            position_all[trial,targ_i,:,1],\n",
    "            color=targ_color, alpha=0.3, linewidth=0.6)\n",
    "\n",
    "    # targ_i is always a multiple of strd here, so the mean is drawn for every plotted target\n",
    "    axs.plot(\n",
    "        position_mean[targ_i,:,0],\n",
    "        position_mean[targ_i,:,1],\n",
    "        color=targ_color, alpha=0.6,\n",
    "    )\n",
    "axs.set_xticks([])\n",
    "axs.set_yticks([])\n",
    "axs.set_xlabel(\"x position\")\n",
    "axs.set_ylabel(\"y position\")\n",
    "axs.set_title(\"decoded reaches\")\n",
    "\n",
    "# Reuse the axis limits stored in dec_xlim / dec_ylim.\n",
    "axs.set_xlim(*dec_xlim)\n",
    "axs.set_ylim(*dec_ylim)\n",
    "\n",
    "plt.savefig(\"../../plots/decoded_gen_reaches.svg\")\n",
    "plt.savefig(\"../../plots/decoded_gen_reaches.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## spike stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Simulate latent trajectories from the prior, driven by the training-set stimulus.\n",
    "# NOTE(review): the following cell assigns z_sim again using eval_dataset, replacing this result.\n",
    "device = \"cuda:0\"\n",
    "u = train_dataset.stim.permute(0,2,1).to(device)\n",
    "initial_states = Qzs_filt[:, 0, :]\n",
    "initial_states_mean = initial_states.mean(axis=0)\n",
    "initial_states_std = initial_states.std(axis=0)\n",
    "noise_scale = 1\n",
    "np.random.seed(1)\n",
    "torch.manual_seed(1)\n",
    "\n",
    "z_sim = np.zeros((u.shape[0], dur, dim_z))\n",
    "for ri in range(u.shape[0]):\n",
    "    # named trial_input rather than `input` so the builtin is not shadowed in the kernel\n",
    "    trial_input = u[ri][None, :, :]\n",
    "    # Start from the index-0 state of the matching row of Qzs_filt.\n",
    "    # Disabled alternative: np.random.randn(dim_z) * initial_states_std + initial_states_mean\n",
    "    z0 = torch.tensor(initial_states[ri], dtype=torch.float, device=device)[None, :, None, None]\n",
    "    z = vae.prior.get_latent_time_series(dur, u=trial_input, z0=z0, noise_scale=noise_scale).cpu().numpy()\n",
    "    z_sim[ri] = z.transpose(0, 2, 1, 3).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulate latent trajectories from the prior, driven by the eval-set stimulus.\n",
    "device = \"cuda:0\"\n",
    "u = eval_dataset.stim.permute(0,2,1).to(device)\n",
    "# NOTE(review): initial states come from Qzs_filt but are indexed by the eval row index ri;\n",
    "# confirm that Qzs_filt is aligned with eval_dataset (and has at least as many rows).\n",
    "initial_states = Qzs_filt[:, 0, :]\n",
    "initial_states_mean = initial_states.mean(axis=0)\n",
    "initial_states_std = initial_states.std(axis=0)\n",
    "noise_scale = 1\n",
    "np.random.seed(1)\n",
    "torch.manual_seed(1)\n",
    "\n",
    "z_sim = np.zeros((u.shape[0], dur, dim_z))\n",
    "for ri in range(u.shape[0]):\n",
    "    # named trial_input rather than `input` so the builtin is not shadowed in the kernel\n",
    "    trial_input = u[ri][None, :, :]\n",
    "    # Disabled alternative: np.random.randn(dim_z) * initial_states_std + initial_states_mean\n",
    "    z0 = torch.tensor(initial_states[ri], dtype=torch.float, device=device)[None, :, None, None]\n",
    "    z = vae.prior.get_latent_time_series(dur, u=trial_input, z0=z0, noise_scale=noise_scale).cpu().numpy()\n",
    "    z_sim[ri] = z.transpose(0, 2, 1, 3).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass each simulated latent trajectory through the observation model without tracking gradients.\n",
    "x_sim = np.empty((*z_sim.shape[:-1], 182))\n",
    "with torch.no_grad():\n",
    "    for i in range(z_sim.shape[0]):\n",
    "        z_i = torch.tensor(z_sim[i], dtype=torch.float, device=device)\n",
    "        # add trailing and leading singleton axes, then swap axes 1 and 2 before calling the model\n",
    "        z_i = z_i.unsqueeze(-1).unsqueeze(0).permute(0, 2, 1, 3)\n",
    "        x = vae.obs_rectify(vae.prior.get_observation(z=z_i)).cpu().numpy()\n",
    "        # swap axes 1 and 2 back and drop singleton axes\n",
    "        x_sim[i] = x.transpose(0, 2, 1, 3).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def coefficient_of_variation(data):\n",
    "    \"\"\"Standard deviation divided by mean, both taken along axis 0.\n",
    "\n",
    "    Entries whose mean is exactly 0 get a CV of 0 instead of a division by zero.\n",
    "    \"\"\"\n",
    "    means = np.asarray(np.mean(data, axis=0), dtype=float)\n",
    "    std_devs = np.asarray(np.std(data, axis=0), dtype=float)\n",
    "    cv = np.zeros_like(means)\n",
    "    # divide only where the mean is non-zero, so no divide-by-zero warning is raised\n",
    "    np.divide(std_devs, means, out=cv, where=(means != 0))\n",
    "    return cv\n",
    "\n",
    "def binned_spikes_to_times(binned_spikes, bin_size=0.02):\n",
    "    \"\"\"Convert a 1-D array of per-bin spike counts into spike times.\n",
    "\n",
    "    A bin holding k spikes contributes k times spread evenly inside the bin,\n",
    "    at (bin index + j / k) * bin_size for j = 0 .. k-1.\n",
    "    \"\"\"\n",
    "    binned_spikes = np.asarray(binned_spikes)\n",
    "    idxs = np.nonzero(binned_spikes)[0]\n",
    "    counts = binned_spikes[idxs].astype(int)\n",
    "    bin_index = np.repeat(idxs, counts).astype(float)\n",
    "    # position of every spike inside its own bin: 0, 1, ..., count - 1\n",
    "    first_in_bin = np.repeat(np.cumsum(counts) - counts, counts)\n",
    "    rank_in_bin = np.arange(counts.sum()) - first_in_bin\n",
    "    spike_times = bin_index + rank_in_bin / np.repeat(counts, counts)\n",
    "    return spike_times * bin_size\n",
    "\n",
    "def compute_isi_stats(spikes, bin_size=0.02):\n",
    "    \"\"\"Mean, standard deviation and CV of inter-spike intervals for each neuron.\n",
    "\n",
    "    `spikes` is indexed as (trial, timestep, neuron). Intervals are computed inside\n",
    "    each trial and pooled over trials. A neuron without any interval (including\n",
    "    the case of zero trials) gets NaN for all three statistics.\n",
    "    Returns three arrays of length n_neurons.\n",
    "    \"\"\"\n",
    "    n_trials, n_timesteps, n_neurons = spikes.shape\n",
    "    isi_means = []\n",
    "    isi_stds = []\n",
    "    isi_cvs = []\n",
    "    for neuron in range(n_neurons):\n",
    "        per_trial = [\n",
    "            np.diff(binned_spikes_to_times(spikes[i, :, neuron], bin_size=bin_size))\n",
    "            for i in range(n_trials)\n",
    "        ]\n",
    "        # np.concatenate raises on an empty list, which happens when there are no trials\n",
    "        isis = np.concatenate(per_trial) if per_trial else np.empty(0)\n",
    "        if isis.size == 0:\n",
    "            isi_means.append(np.nan)\n",
    "            isi_stds.append(np.nan)\n",
    "            isi_cvs.append(np.nan)\n",
    "            continue\n",
    "        isi_means.append(np.mean(isis))\n",
    "        isi_stds.append(np.std(isis))\n",
    "        isi_cvs.append(coefficient_of_variation(isis))\n",
    "    return np.array(isi_means), np.array(isi_stds), np.array(isi_cvs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overall statistics of the recorded training spikes, pooled over axes 0 and 1.\n",
    "spikes = train_dataset.data.detach().cpu().numpy()\n",
    "\n",
    "sr_mean_all = np.mean(spikes, axis=(0,1))\n",
    "sr_std_all = np.std(spikes, axis=(0,1))\n",
    "isi_mean_all, isi_std_all, isi_cv_all = compute_isi_stats(spikes)\n",
    "\n",
    "# one histogram panel per statistic\n",
    "hist_panels = [\n",
    "    (\"mean spike rate\", sr_mean_all),\n",
    "    (\"std spike rate\", sr_std_all),\n",
    "    (\"mean isi\", isi_mean_all),\n",
    "    (\"std isi\", isi_std_all),\n",
    "    (\"isi cv\", isi_cv_all),\n",
    "]\n",
    "fig, axs = plt.subplots(1, 5, figsize=(5,1))\n",
    "for panel_ax, (panel_title, panel_values) in zip(axs, hist_panels):\n",
    "    panel_ax.hist(panel_values)\n",
    "    panel_ax.set_title(panel_title)\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overall statistics of the recorded eval spikes, pooled over axes 0 and 1.\n",
    "test_spikes = eval_dataset.data.detach().cpu().numpy()\n",
    "\n",
    "test_sr_mean_all = np.mean(test_spikes, axis=(0,1))\n",
    "test_sr_std_all = np.std(test_spikes, axis=(0,1))\n",
    "test_isi_mean_all, test_isi_std_all, test_isi_cv_all = compute_isi_stats(test_spikes)\n",
    "\n",
    "# one histogram panel per statistic\n",
    "hist_panels = [\n",
    "    (\"mean spike rate\", test_sr_mean_all),\n",
    "    (\"std spike rate\", test_sr_std_all),\n",
    "    (\"mean isi\", test_isi_mean_all),\n",
    "    (\"std isi\", test_isi_std_all),\n",
    "    (\"isi cv\", test_isi_cv_all),\n",
    "]\n",
    "fig, axs = plt.subplots(1, 5, figsize=(5,1))\n",
    "for panel_ax, (panel_title, panel_values) in zip(axs, hist_panels):\n",
    "    panel_ax.hist(panel_values)\n",
    "    panel_ax.set_title(panel_title)\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw Poisson spike counts with x_sim as the mean (fixed seed) and summarise them.\n",
    "\n",
    "np.random.seed(0)\n",
    "inferred_spikes = np.random.poisson(x_sim)\n",
    "\n",
    "inf_sr_mean_all = np.mean(inferred_spikes, axis=(0,1))\n",
    "inf_sr_std_all = np.std(inferred_spikes, axis=(0,1))\n",
    "inf_isi_mean_all, inf_isi_std_all, inf_isi_cv_all = compute_isi_stats(inferred_spikes)\n",
    "\n",
    "# one histogram panel per statistic\n",
    "hist_panels = [\n",
    "    (\"mean spike rate\", inf_sr_mean_all),\n",
    "    (\"std spike rate\", inf_sr_std_all),\n",
    "    (\"mean isi\", inf_isi_mean_all),\n",
    "    (\"std isi\", inf_isi_std_all),\n",
    "    (\"isi cv\", inf_isi_cv_all),\n",
    "]\n",
    "fig, axs = plt.subplots(1, 5, figsize=(5,1))\n",
    "for panel_ax, (panel_title, panel_values) in zip(axs, hist_panels):\n",
    "    panel_ax.hist(panel_values)\n",
    "    panel_ax.set_title(panel_title)\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-neuron statistics of the sampled model spikes (x) against the training data (y).\n",
    "true_vals = [sr_mean_all, sr_std_all, isi_mean_all, isi_std_all, isi_cv_all]\n",
    "inf_vals = [inf_sr_mean_all, inf_sr_std_all, inf_isi_mean_all, inf_isi_std_all, inf_isi_cv_all]\n",
    "labels = [\"mean spike rate\", \"std spike rate\", \"mean isi\", \"std isi\", \"isi cv\"]\n",
    "\n",
    "fig, axs = plt.subplots(1, 5, figsize=(7.5,1.8))\n",
    "for panel_ax, true, inf, label in zip(axs, true_vals, inf_vals, labels):\n",
    "    panel_ax.scatter(inf, true)\n",
    "    # common upper limit taken from the autoscaled scatter, so both axes span the same range\n",
    "    xlim = panel_ax.get_xlim()\n",
    "    ylim = panel_ax.get_ylim()\n",
    "    upper = max(xlim[1], ylim[1])\n",
    "    # identity line over the whole visible range; a fixed -0.5..1.5 segment would\n",
    "    # stop short whenever the limit exceeds 1.5\n",
    "    panel_ax.plot([0, upper], [0, upper], linestyle='--', color='black')\n",
    "    panel_ax.set_xlim(0, upper)\n",
    "    panel_ax.set_ylim(0, upper)\n",
    "    panel_ax.set_title(label)\n",
    "    panel_ax.set_xlabel(\"inferred\")\n",
    "axs[0].set_ylabel(\"true\")\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-neuron statistics of the eval data (x) against the training data (y).\n",
    "# NOTE(review): the x-axis is labelled \"inferred\" although inf_vals holds the eval-set\n",
    "# statistics (test_*) in this cell -- confirm whether the label should change.\n",
    "true_vals = [sr_mean_all, sr_std_all, isi_mean_all, isi_std_all, isi_cv_all]\n",
    "inf_vals = [test_sr_mean_all, test_sr_std_all, test_isi_mean_all, test_isi_std_all, test_isi_cv_all]\n",
    "labels = [\"mean spike rate\", \"std spike rate\", \"mean isi\", \"std isi\", \"isi cv\"]\n",
    "\n",
    "fig, axs = plt.subplots(1, 5, figsize=(7.5,1.8))\n",
    "for panel_ax, true, inf, label in zip(axs, true_vals, inf_vals, labels):\n",
    "    panel_ax.scatter(inf, true)\n",
    "    # common upper limit taken from the autoscaled scatter, so both axes span the same range\n",
    "    xlim = panel_ax.get_xlim()\n",
    "    ylim = panel_ax.get_ylim()\n",
    "    upper = max(xlim[1], ylim[1])\n",
    "    # identity line over the whole visible range; a fixed -0.5..1.5 segment would\n",
    "    # stop short whenever the limit exceeds 1.5\n",
    "    panel_ax.plot([0, upper], [0, upper], linestyle='--', color='black')\n",
    "    panel_ax.set_xlim(0, upper)\n",
    "    panel_ax.set_ylim(0, upper)\n",
    "    panel_ax.set_title(label)\n",
    "    panel_ax.set_xlabel(\"inferred\")\n",
    "axs[0].set_ylabel(\"true\")\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# conditional true spike stats\n",
    "\n",
    "spikes = train_dataset.data.detach().cpu().numpy()\n",
    "target_pos = train_dataset.stim[:, 0, :].detach().cpu().numpy()\n",
    "# Angle of the stimulus vector at index 0 of axis 1, rescaled from [-pi, pi] to [0, 1]\n",
    "# and snapped to eighths (k / 8 for k = 0..7).\n",
    "angles = np.arctan2(target_pos[:, 1], target_pos[:, 0])\n",
    "angles = angles / (2 * np.pi) + 0.5\n",
    "angles = (np.round(angles * 8) % 8) / 8\n",
    "\n",
    "cond_angles = np.unique(angles)  # np.unique returns sorted values\n",
    "cond_stat_shape = (len(cond_angles), spikes.shape[-1])\n",
    "sr_mean_cond = np.empty(cond_stat_shape)\n",
    "sr_std_cond = np.empty(cond_stat_shape)\n",
    "isi_mean_cond = np.empty(cond_stat_shape)\n",
    "isi_std_cond = np.empty(cond_stat_shape)\n",
    "isi_cv_cond = np.empty(cond_stat_shape)\n",
    "# one row of statistics per distinct angle\n",
    "for i, angle in enumerate(cond_angles):\n",
    "    mask = (angles == angle)\n",
    "    cond_spikes = spikes[mask]\n",
    "\n",
    "    sr_mean_cond[i] = cond_spikes.mean(axis=(0,1))\n",
    "    sr_std_cond[i] = cond_spikes.std(axis=(0,1))\n",
    "    isi_mean_cond[i], isi_std_cond[i], isi_cv_cond[i] = compute_isi_stats(cond_spikes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# conditional test true spike stats\n",
    "\n",
    "test_spikes = eval_dataset.data.detach().cpu().numpy()\n",
    "test_target_pos = eval_dataset.stim[:, 0, :].detach().cpu().numpy()\n",
    "# Angle of the stimulus vector at index 0 of axis 1, rescaled from [-pi, pi] to [0, 1]\n",
    "# and snapped to eighths (k / 8 for k = 0..7).\n",
    "test_angles = np.arctan2(test_target_pos[:, 1], test_target_pos[:, 0])\n",
    "test_angles = test_angles / (2 * np.pi) + 0.5\n",
    "test_angles = (np.round(test_angles * 8) % 8) / 8\n",
    "\n",
    "test_cond_angles = np.unique(test_angles)  # np.unique returns sorted values\n",
    "test_cond_stat_shape = (len(test_cond_angles), test_spikes.shape[-1])\n",
    "test_sr_mean_cond = np.empty(test_cond_stat_shape)\n",
    "test_sr_std_cond = np.empty(test_cond_stat_shape)\n",
    "test_isi_mean_cond = np.empty(test_cond_stat_shape)\n",
    "test_isi_std_cond = np.empty(test_cond_stat_shape)\n",
    "test_isi_cv_cond = np.empty(test_cond_stat_shape)\n",
    "# one row of statistics per distinct angle\n",
    "for i, angle in enumerate(test_cond_angles):\n",
    "    mask = (test_angles == angle)\n",
    "    test_cond_spikes = test_spikes[mask]\n",
    "\n",
    "    test_sr_mean_cond[i] = test_cond_spikes.mean(axis=(0,1))\n",
    "    test_sr_std_cond[i] = test_cond_spikes.std(axis=(0,1))\n",
    "    test_isi_mean_cond[i], test_isi_std_cond[i], test_isi_cv_cond[i] = compute_isi_stats(test_cond_spikes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Condition-by-condition correlation of each statistic of the training data.\n",
    "# For the ISI statistics, neurons (columns) with a NaN in any condition are dropped first.\n",
    "corr_panels = [\n",
    "    (\"mean spike rate\", sr_mean_cond, False),\n",
    "    (\"std spike rate\", sr_std_cond, False),\n",
    "    (\"mean isi\", isi_mean_cond, True),\n",
    "    (\"std isi\", isi_std_cond, True),\n",
    "    (\"isi cv\", isi_cv_cond, True),\n",
    "]\n",
    "\n",
    "fig, axs = plt.subplots(1, 5, figsize=(7.5,1.5))\n",
    "for panel_ax, (panel_title, cond_stat, drop_nan) in zip(axs, corr_panels):\n",
    "    if drop_nan:\n",
    "        mask = np.any(np.isnan(cond_stat), axis=0)\n",
    "        cond_stat = cond_stat[:, ~mask]\n",
    "    panel_ax.imshow(np.corrcoef(cond_stat), vmin=0, vmax=1)\n",
    "    panel_ax.set_title(panel_title)\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-condition statistics of the sampled model spikes.\n",
    "# NOTE(review): conditions are enumerated from `angles` while rows of inferred_spikes are\n",
    "# selected with `test_angles` -- confirm both contain the same set of values.\n",
    "cond_angles = np.unique(angles)  # np.unique returns sorted values\n",
    "inf_stat_shape = (len(cond_angles), spikes.shape[-1])\n",
    "inf_sr_mean_cond = np.empty(inf_stat_shape)\n",
    "inf_sr_std_cond = np.empty(inf_stat_shape)\n",
    "inf_isi_mean_cond = np.empty(inf_stat_shape)\n",
    "inf_isi_std_cond = np.empty(inf_stat_shape)\n",
    "inf_isi_cv_cond = np.empty(inf_stat_shape)\n",
    "for i, angle in enumerate(cond_angles):\n",
    "    mask = (test_angles == angle)\n",
    "    inf_cond_spikes = inferred_spikes[mask]\n",
    "\n",
    "    inf_sr_mean_cond[i] = inf_cond_spikes.mean(axis=(0,1))\n",
    "    inf_sr_std_cond[i] = inf_cond_spikes.std(axis=(0,1))\n",
    "    inf_isi_mean_cond[i], inf_isi_std_cond[i], inf_isi_cv_cond[i] = compute_isi_stats(inf_cond_spikes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Condition-by-condition correlation of each statistic of the sampled model spikes.\n",
    "# For the ISI statistics, neurons (columns) with a NaN in any condition are dropped first.\n",
    "corr_panels = [\n",
    "    (\"mean spike rate\", inf_sr_mean_cond, False),\n",
    "    (\"std spike rate\", inf_sr_std_cond, False),\n",
    "    (\"mean isi\", inf_isi_mean_cond, True),\n",
    "    (\"std isi\", inf_isi_std_cond, True),\n",
    "    (\"isi cv\", inf_isi_cv_cond, True),\n",
    "]\n",
    "\n",
    "fig, axs = plt.subplots(1, 5, figsize=(7.5,1.5))\n",
    "for panel_ax, (panel_title, cond_stat, drop_nan) in zip(axs, corr_panels):\n",
    "    if drop_nan:\n",
    "        mask = np.any(np.isnan(cond_stat), axis=0)\n",
    "        cond_stat = cond_stat[:, ~mask]\n",
    "    panel_ax.imshow(np.corrcoef(cond_stat), vmin=0, vmax=1)\n",
    "    panel_ax.set_title(panel_title)\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rmse(pred, target):\n",
    "    \"\"\"Root-mean-square error between `pred` and `target`.\"\"\"\n",
    "    return np.sqrt(np.mean(np.square(pred - target)))\n",
    "\n",
    "def nanrmse(pred, target, warn_extra_nan=False):\n",
    "    \"\"\"RMSE over the entries where neither `pred` nor `target` is NaN.\n",
    "\n",
    "    With `warn_extra_nan=True` a message is printed when `pred` is NaN at a\n",
    "    position where `target` is valid.\n",
    "    \"\"\"\n",
    "    mask = np.isnan(target)\n",
    "    if np.any(np.isnan(pred[~mask])) and warn_extra_nan:\n",
    "        print(\"nans found in predictions for valid targets\")\n",
    "    # exclude positions that are NaN in either array\n",
    "    mask = np.logical_or(mask, np.isnan(pred))\n",
    "    return np.sqrt(np.mean(np.square(pred[~mask] - target[~mask])))\n",
    "\n",
    "# errors of the overall (all-trial) model statistics against the training data\n",
    "sr_mean_all_rmse = rmse(inf_sr_mean_all, sr_mean_all)\n",
    "sr_std_all_rmse = rmse(inf_sr_std_all, sr_std_all)\n",
    "isi_mean_all_rmse = nanrmse(inf_isi_mean_all, isi_mean_all)\n",
    "isi_std_all_rmse = nanrmse(inf_isi_std_all, isi_std_all)\n",
    "\n",
    "print(f\"{sr_mean_all_rmse=}\\n{sr_std_all_rmse=}\\n{isi_mean_all_rmse=}\\n{isi_std_all_rmse=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_conds = sr_mean_cond.shape[0]\n",
    "\n",
    "# RMSE between model and training statistics per condition, averaged over conditions.\n",
    "sr_mean_cond_rmse = np.mean(\n",
    "    [rmse(pred, target) for pred, target in zip(inf_sr_mean_cond, sr_mean_cond)]\n",
    ")\n",
    "sr_std_cond_rmse = np.mean(\n",
    "    [rmse(pred, target) for pred, target in zip(inf_sr_std_cond, sr_std_cond)]\n",
    ")\n",
    "isi_mean_cond_rmse = np.mean(\n",
    "    [nanrmse(pred, target) for pred, target in zip(inf_isi_mean_cond, isi_mean_cond)]\n",
    ")\n",
    "isi_std_cond_rmse = np.mean(\n",
    "    [nanrmse(pred, target) for pred, target in zip(inf_isi_std_cond, isi_std_cond)]\n",
    ")\n",
    "\n",
    "print(f\"{sr_mean_cond_rmse=}\\n{sr_std_cond_rmse=}\\n{isi_mean_cond_rmse=}\\n{isi_std_cond_rmse=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_conds = sr_mean_cond.shape[0]\n",
    "\n",
    "# Baseline: use the overall (all-trial) statistic as the prediction for every condition,\n",
    "# then average the per-condition RMSE over conditions.\n",
    "bl_sr_mean_cond_rmse = np.mean([rmse(sr_mean_all, target) for target in sr_mean_cond])\n",
    "bl_sr_std_cond_rmse = np.mean([rmse(sr_std_all, target) for target in sr_std_cond])\n",
    "bl_isi_mean_cond_rmse = np.mean([nanrmse(isi_mean_all, target) for target in isi_mean_cond])\n",
    "bl_isi_std_cond_rmse = np.mean([nanrmse(isi_std_all, target) for target in isi_std_cond])\n",
    "\n",
    "print(f\"{bl_sr_mean_cond_rmse=}\\n{bl_sr_std_cond_rmse=}\\n{bl_isi_mean_cond_rmse=}\\n{bl_isi_std_cond_rmse=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 2, figsize=(2.4, 1.4))\n",
    "n_cond = sr_mean_cond.shape[0]\n",
    "\n",
    "# np.corrcoef stacks the rows of both inputs; the [n_cond:, :n_cond] block holds the\n",
    "# correlations of the test conditions (rows) with the first input's conditions (columns).\n",
    "train_vs_test = np.corrcoef(sr_mean_cond, test_sr_mean_cond)[n_cond:, :n_cond]\n",
    "model_vs_test = np.corrcoef(inf_sr_mean_cond, test_sr_mean_cond)[n_cond:, :n_cond]\n",
    "\n",
    "for panel_ax, corr_block, panel_title in zip(axs, (train_vs_test, model_vs_test), (\"true\", \"model\")):\n",
    "    panel_ax.imshow(corr_block, vmin=0.0, vmax=1, cmap=\"coolwarm\")\n",
    "    panel_ax.set_title(panel_title)\n",
    "    panel_ax.set_xticks([])\n",
    "    panel_ax.set_yticks([])\n",
    "    panel_ax.set_xlabel(\"condition\")\n",
    "axs[0].set_ylabel(\"condition\")\n",
    "\n",
    "plt.suptitle(\"     Spike rate correlation\", y=0.95)\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(\"../../plots/cond_spike_corr.svg\")\n",
    "plt.savefig(\"../../plots/cond_spike_corr.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 2, figsize=(2.0, 1.0))\n",
    "n_cond = sr_mean_cond.shape[0]\n",
    "\n",
    "vmin = 0.0\n",
    "vmax = 0.3\n",
    "\n",
    "# correlation distance (1 - r) between conditions, for the eval data and for the model\n",
    "true_dissim = 1 - np.corrcoef(test_sr_mean_cond)\n",
    "model_dissim = 1 - np.corrcoef(inf_sr_mean_cond)\n",
    "\n",
    "for panel_ax, dissim, panel_title in zip(axs, (true_dissim, model_dissim), (\"true\", \"model\")):\n",
    "    panel_ax.imshow(dissim, vmin=vmin, vmax=vmax, cmap=\"coolwarm\")\n",
    "    panel_ax.set_title(panel_title)\n",
    "    panel_ax.set_xticks([])\n",
    "    panel_ax.set_yticks([])\n",
    "    panel_ax.set_xlabel(\"condition\")\n",
    "axs[0].set_ylabel(\"condition\")\n",
    "\n",
    "# NOTE(review): these paths are also written by the spike-rate correlation figure cell,\n",
    "# so whichever cell runs last determines the file contents.\n",
    "plt.savefig(\"../../plots/cond_spike_corr.svg\")\n",
    "plt.savefig(\"../../plots/cond_spike_corr.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standalone vertical colorbar; vmin / vmax are whatever the previously run cell left behind.\n",
    "fig, axs = plt.subplots(figsize=(0.25 / 3, 3 / 3))\n",
    "cmap = mpl.cm.coolwarm\n",
    "norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)\n",
    "\n",
    "cb1 = mpl.colorbar.ColorbarBase(\n",
    "    axs,\n",
    "    cmap=cmap,\n",
    "    norm=norm,\n",
    "    orientation='vertical',\n",
    ")\n",
    "cb1.set_ticks([0.0, 0.15, 0.3])\n",
    "cb1.set_label(\"corr. distance\")\n",
    "\n",
    "plt.savefig(\"../../plots/fr_colorbar.svg\")\n",
    "plt.savefig(\"../../plots/fr_colorbar.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 1, figsize=(1.0, 1.0))\n",
    "n_cond = isi_mean_cond.shape[0]\n",
    "\n",
    "vmin = 0.0\n",
    "vmax = 0.3\n",
    "\n",
    "# correlation distance (1 - r) between conditions of the eval data\n",
    "test_dissim = 1 - np.corrcoef(test_sr_mean_cond)\n",
    "axs.imshow(test_dissim, vmin=vmin, vmax=vmax, cmap=\"coolwarm\")\n",
    "axs.set_title(\"true\")\n",
    "axs.set_xticks([])\n",
    "axs.set_yticks([])\n",
    "axs.set_xlabel(\"condition\")\n",
    "axs.set_ylabel(\"condition\")\n",
    "\n",
    "plt.savefig(\"../../plots/cond_test_spike_corr.svg\")\n",
    "plt.savefig(\"../../plots/cond_test_spike_corr.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 2, figsize=(2.0, 1.0))\n",
    "n_cond = isi_mean_cond.shape[0]\n",
    "\n",
    "vmin = 0.0\n",
    "vmax = 1.0\n",
    "\n",
    "# Drop every neuron (column) that has a NaN mean ISI in any condition of any of the three sets.\n",
    "nan_mask = np.logical_or.reduce([\n",
    "    np.any(np.isnan(cond_stat), axis=0)\n",
    "    for cond_stat in (isi_mean_cond, inf_isi_mean_cond, test_isi_mean_cond)\n",
    "])\n",
    "\n",
    "# correlation distance (1 - r) between conditions, for the eval data and for the model\n",
    "true_dissim = 1 - np.corrcoef(test_isi_mean_cond[:,~nan_mask])\n",
    "model_dissim = 1 - np.corrcoef(inf_isi_mean_cond[:,~nan_mask])\n",
    "\n",
    "for panel_ax, dissim, panel_title in zip(axs, (true_dissim, model_dissim), (\"true\", \"model\")):\n",
    "    panel_ax.imshow(dissim, vmin=vmin, vmax=vmax, cmap=\"coolwarm\")\n",
    "    panel_ax.set_title(panel_title)\n",
    "    panel_ax.set_xticks([])\n",
    "    panel_ax.set_yticks([])\n",
    "    panel_ax.set_xlabel(\"condition\")\n",
    "axs[0].set_ylabel(\"condition\")\n",
    "\n",
    "plt.savefig(\"../../plots/cond_isi_corr.svg\")\n",
    "plt.savefig(\"../../plots/cond_isi_corr.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standalone vertical colorbar; vmin / vmax are whatever the previously run cell left behind.\n",
    "fig, axs = plt.subplots(figsize=(0.25 / 3, 3 / 3))\n",
    "cmap = mpl.cm.coolwarm\n",
    "norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)\n",
    "\n",
    "cb1 = mpl.colorbar.ColorbarBase(\n",
    "    axs,\n",
    "    cmap=cmap,\n",
    "    norm=norm,\n",
    "    orientation='vertical',\n",
    ")\n",
    "cb1.set_ticks([0.0, 0.5, 1.0])\n",
    "cb1.set_label(\"corr. distance\")\n",
    "\n",
    "plt.savefig(\"../../plots/isi_colorbar.svg\")\n",
    "plt.savefig(\"../../plots/isi_colorbar.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rmse distance matrix\n",
    "\n",
    "def distance_matrix(a, b=None):\n",
    "    \"\"\"Pairwise RMSE between the rows of `a` and the rows of `b`.\n",
    "\n",
    "    Columns that contain a NaN in either input are dropped before comparing.\n",
    "    With `b=None` the rows of `a` are compared with each other.\n",
    "    Both inputs must be 2-D with the same number of columns.\n",
    "    Returns an array of shape (a.shape[0], b.shape[0]).\n",
    "    \"\"\"\n",
    "    if b is None:\n",
    "        b = a\n",
    "    # accept any array-like, not only ndarrays\n",
    "    a = np.asarray(a)\n",
    "    b = np.asarray(b)\n",
    "    assert a.ndim == 2\n",
    "    assert b.ndim == 2\n",
    "    assert a.shape[1] == b.shape[1]\n",
    "    nan_mask = np.logical_or(\n",
    "        np.any(np.isnan(a), axis=0),\n",
    "        np.any(np.isnan(b), axis=0),\n",
    "    )\n",
    "    a = a[:, ~nan_mask]\n",
    "    b = b[:, ~nan_mask]\n",
    "    # broadcast to (rows of a, rows of b, columns), then reduce over the column axis\n",
    "    dist_mat = a[:, None] - b[None, :]\n",
    "    dist_mat = np.sqrt(np.mean(np.square(dist_mat), axis=-1))\n",
    "    return dist_mat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 2, figsize=(2.0, 1.0))\n",
    "n_cond = sr_mean_cond.shape[0]\n",
    "\n",
    "# RMSE between every pair of conditions, within the training data and within the model.\n",
    "train_test_mat = distance_matrix(sr_mean_cond, sr_mean_cond)\n",
    "model_test_mat = distance_matrix(inf_sr_mean_cond, inf_sr_mean_cond)\n",
    "vmin = 0.0\n",
    "vmax = 0.08\n",
    "\n",
    "for panel_ax, dist_mat, panel_title in zip(axs, (train_test_mat, model_test_mat), (\"true\", \"model\")):\n",
    "    panel_ax.imshow(dist_mat, vmin=vmin, vmax=vmax, cmap=\"coolwarm\")\n",
    "    panel_ax.set_title(panel_title)\n",
    "    panel_ax.set_xticks([])\n",
    "    panel_ax.set_yticks([])\n",
    "    panel_ax.set_xlabel(\"condition\")\n",
    "axs[0].set_ylabel(\"condition\")\n",
    "\n",
    "plt.savefig(\"../../plots/cond_fr_dm.svg\")\n",
    "plt.savefig(\"../../plots/cond_fr_dm.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stand-alone vertical colourbar saved as fr_colorbar.{svg,pdf}.\n",
    "# Reads `vmin` / `vmax` from the notebook namespace, so the cell that sets\n",
    "# them must have been run first.\n",
    "fig, axs = plt.subplots(figsize=(0.25 / 3, 3 / 3))\n",
    "cmap, norm = mpl.cm.coolwarm, mpl.colors.Normalize(vmin=vmin, vmax=vmax)\n",
    "\n",
    "cb1 = mpl.colorbar.ColorbarBase(\n",
    "    axs,\n",
    "    cmap=cmap,\n",
    "    norm=norm,\n",
    "    orientation=\"vertical\",\n",
    ")\n",
    "# Tick the two ends of the scale and its midpoint.\n",
    "tick_values = [vmin, (vmin + vmax) / 2, vmax]\n",
    "cb1.set_ticks(tick_values)\n",
    "cb1.set_label(\"RMSE\")\n",
    "\n",
    "for out_path in (\"../../plots/fr_colorbar.svg\", \"../../plots/fr_colorbar.pdf\"):\n",
    "    plt.savefig(out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RMSE distance matrices between conditions for the ISI-mean statistics:\n",
    "# left panel from `isi_mean_cond`, right panel from `inf_isi_mean_cond`.\n",
    "fig, axs = plt.subplots(1, 2, figsize=(2.0, 1.0))\n",
    "n_cond = sr_mean_cond.shape[0]\n",
    "\n",
    "# Discard every column (axis 1) that has a NaN in any of the three arrays so\n",
    "# that both matrices are computed over the same columns.\n",
    "nan_mask = np.logical_or.reduce([\n",
    "    np.any(np.isnan(isi_mean_cond), axis=0),\n",
    "    np.any(np.isnan(inf_isi_mean_cond), axis=0),\n",
    "    np.any(np.isnan(test_isi_mean_cond), axis=0),\n",
    "])\n",
    "\n",
    "train_test_mat = distance_matrix(isi_mean_cond[:, ~nan_mask], isi_mean_cond[:, ~nan_mask])\n",
    "model_test_mat = distance_matrix(inf_isi_mean_cond[:, ~nan_mask], inf_isi_mean_cond[:, ~nan_mask])\n",
    "vmin = 0.0\n",
    "# Upper limit: largest distance in either matrix, rounded up to two decimals.\n",
    "vmax = np.ceil(max(np.max(train_test_mat), np.max(model_test_mat)) * 100) / 100\n",
    "# vmax = 0.08\n",
    "\n",
    "# Both panels use the same (vmin, vmax) range so that they are directly\n",
    "# comparable and agree with a colourbar built from these limits; without an\n",
    "# explicit vmax each panel would autoscale to its own maximum.\n",
    "axs[0].imshow(train_test_mat, vmin=vmin, vmax=vmax, cmap=\"coolwarm\")\n",
    "axs[0].set_title(\"true\")\n",
    "axs[0].set_xticks([])\n",
    "axs[0].set_yticks([])\n",
    "axs[0].set_xlabel(\"condition\")\n",
    "axs[0].set_ylabel(\"condition\")\n",
    "\n",
    "axs[1].imshow(model_test_mat, vmin=vmin, vmax=vmax, cmap=\"coolwarm\")\n",
    "axs[1].set_title(\"model\")\n",
    "axs[1].set_xticks([])\n",
    "axs[1].set_yticks([])\n",
    "axs[1].set_xlabel(\"condition\")\n",
    "# plt.suptitle(\"  ISI mean dissimilarity\", y=0.95)\n",
    "# plt.tight_layout()\n",
    "\n",
    "plt.savefig(\"../../plots/cond_isimean_dm.svg\")\n",
    "plt.savefig(\"../../plots/cond_isimean_dm.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the colour-scale upper bound (`vmax`) currently held in the kernel.\n",
    "vmax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stand-alone vertical colourbar saved as isimean_colorbar.{svg,pdf}.\n",
    "# Reads `vmin` / `vmax` from the notebook namespace, so the cell that sets\n",
    "# them must have been run first.\n",
    "fig, axs = plt.subplots(figsize=(0.25 / 3, 3 / 3))\n",
    "cmap, norm = mpl.cm.coolwarm, mpl.colors.Normalize(vmin=vmin, vmax=vmax)\n",
    "\n",
    "cb1 = mpl.colorbar.ColorbarBase(\n",
    "    axs,\n",
    "    cmap=cmap,\n",
    "    norm=norm,\n",
    "    orientation=\"vertical\",\n",
    ")\n",
    "# Tick the two ends of the scale and its midpoint.\n",
    "tick_values = [vmin, (vmin + vmax) / 2, vmax]\n",
    "cb1.set_ticks(tick_values)\n",
    "cb1.set_label(\"RMSE\")\n",
    "\n",
    "for out_path in (\"../../plots/isimean_colorbar.svg\", \"../../plots/isimean_colorbar.pdf\"):\n",
    "    plt.savefig(out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RMSE distance matrices between conditions for the ISI-CV statistics:\n",
    "# left panel from `test_isi_cv_cond`, right panel from `inf_isi_cv_cond`.\n",
    "fig, axs = plt.subplots(1, 2, figsize=(2.4, 1.4))\n",
    "n_cond = sr_mean_cond.shape[0]\n",
    "\n",
    "# Discard every column (axis 1) that has a NaN in any of the three arrays so\n",
    "# that both matrices are computed over the same columns.\n",
    "nan_mask = np.logical_or.reduce([\n",
    "    np.any(np.isnan(isi_cv_cond), axis=0),\n",
    "    np.any(np.isnan(test_isi_cv_cond), axis=0),\n",
    "    np.any(np.isnan(inf_isi_cv_cond), axis=0),\n",
    "])\n",
    "\n",
    "train_test_mat = distance_matrix(test_isi_cv_cond[:, ~nan_mask], test_isi_cv_cond[:, ~nan_mask])\n",
    "model_test_mat = distance_matrix(inf_isi_cv_cond[:, ~nan_mask], inf_isi_cv_cond[:, ~nan_mask])\n",
    "vmin = 0.0\n",
    "vmax = 0.6 # np.ceil(max(np.max(train_test_mat), np.max(model_test_mat)) * 100) / 100\n",
    "\n",
    "# Both panels use the same (vmin, vmax) range so that they are directly\n",
    "# comparable and agree with a colourbar built from these limits; without an\n",
    "# explicit vmax each panel would autoscale to its own maximum.\n",
    "axs[0].imshow(train_test_mat, vmin=vmin, vmax=vmax, cmap=\"coolwarm\")\n",
    "# axs[0].set_title(\"true\")\n",
    "axs[0].set_xticks([])\n",
    "axs[0].set_yticks([])\n",
    "# NOTE(review): the labels read train/test, but the matrix above is\n",
    "# test-vs-test -- confirm which pairing is intended.\n",
    "axs[0].set_xlabel(\"train\")\n",
    "axs[0].set_ylabel(\"test\")\n",
    "\n",
    "axs[1].imshow(model_test_mat, vmin=vmin, vmax=vmax, cmap=\"coolwarm\")\n",
    "# axs[1].set_title(\"model\")\n",
    "axs[1].set_xticks([])\n",
    "axs[1].set_yticks([])\n",
    "axs[1].set_xlabel(\"model\")\n",
    "plt.suptitle(\"  ISI CV dissimilarity\", y=0.95)\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(\"../../plots/cond_isicv_dm.svg\")\n",
    "plt.savefig(\"../../plots/cond_isicv_dm.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stand-alone vertical colourbar saved as isicv_colorbar.{svg,pdf}.\n",
    "# Reads `vmin` / `vmax` from the notebook namespace, so the cell that sets\n",
    "# them must have been run first.\n",
    "fig, axs = plt.subplots(figsize=(0.25 / 3, 3 / 3))\n",
    "cmap = mpl.cm.coolwarm\n",
    "norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)\n",
    "\n",
    "cb1 = mpl.colorbar.ColorbarBase(axs, cmap=cmap,\n",
    "                                norm=norm,\n",
    "                                orientation='vertical')\n",
    "# Derive the ticks from the limits (ends and midpoint) so that they follow\n",
    "# `vmax` when it changes, instead of hard-coded values that can leave the\n",
    "# top of the scale unlabelled.\n",
    "cb1.set_ticks([vmin, (vmin + vmax) / 2, vmax])\n",
    "cb1.set_label(\"RMSE\")\n",
    "\n",
    "plt.savefig(\"../../plots/isicv_colorbar.svg\")\n",
    "plt.savefig(\"../../plots/isicv_colorbar.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ISI-CV distance matrices between conditions, passing the arrays straight to\n",
    "# `distance_matrix` (which drops NaN-containing columns itself).\n",
    "# NOTE(review): the figure is saved as \"cond_spike_corr\" although it shows\n",
    "# ISI-CV distances, and the ISI-mean cell that follows writes to the same two\n",
    "# paths -- confirm the intended file names.\n",
    "fig, axs = plt.subplots(1, 2, figsize=(2.4, 1.4))\n",
    "n_cond = sr_mean_cond.shape[0]\n",
    "\n",
    "for panel_ax, panel_data, panel_title in zip(\n",
    "    axs, (isi_cv_cond, inf_isi_cv_cond), (\"true\", \"model\")\n",
    "):\n",
    "    panel_ax.imshow(distance_matrix(panel_data, panel_data), vmin=0.0, cmap=\"coolwarm\")\n",
    "    panel_ax.set_title(panel_title)\n",
    "    panel_ax.set_xticks([])\n",
    "    panel_ax.set_yticks([])\n",
    "    panel_ax.set_xlabel(\"condition\")\n",
    "# Only the left panel carries a y-label.\n",
    "axs[0].set_ylabel(\"condition\")\n",
    "plt.suptitle(\"  ISI CV dissimilarity\", y=0.95)\n",
    "plt.tight_layout()\n",
    "\n",
    "for out_path in (\"../../plots/cond_spike_corr.svg\", \"../../plots/cond_spike_corr.pdf\"):\n",
    "    plt.savefig(out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ISI-mean distance matrices between conditions, passing the arrays straight\n",
    "# to `distance_matrix` (which drops NaN-containing columns itself).\n",
    "fig, axs = plt.subplots(1, 2, figsize=(2.4, 1.4))\n",
    "n_cond = sr_mean_cond.shape[0]\n",
    "\n",
    "axs[0].imshow(distance_matrix(isi_mean_cond, isi_mean_cond), vmin=0.0, cmap=\"coolwarm\")\n",
    "axs[0].set_title(\"true\")\n",
    "axs[0].set_xticks([])\n",
    "axs[0].set_yticks([])\n",
    "axs[0].set_xlabel(\"condition\")\n",
    "axs[0].set_ylabel(\"condition\")\n",
    "\n",
    "axs[1].imshow(distance_matrix(inf_isi_mean_cond, inf_isi_mean_cond), vmin=0.0, cmap=\"coolwarm\")\n",
    "axs[1].set_title(\"model\")\n",
    "axs[1].set_xticks([])\n",
    "axs[1].set_yticks([])\n",
    "axs[1].set_xlabel(\"condition\")\n",
    "# The data plotted here are the ISI means, so the title says so (it used to\n",
    "# read \"ISI CV\", copied from the ISI-CV cell).\n",
    "plt.suptitle(\"  ISI mean dissimilarity\", y=0.95)\n",
    "plt.tight_layout()\n",
    "\n",
    "# NOTE(review): the ISI-CV cell just before this one saves to the same two\n",
    "# \"cond_spike_corr\" paths, so this figure overwrites it -- confirm the\n",
    "# intended file names.\n",
    "plt.savefig(\"../../plots/cond_spike_corr.svg\")\n",
    "plt.savefig(\"../../plots/cond_spike_corr.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## todo\n",
    "\n",
    "- spike stats of conditional generation, not inferred\n",
    "- generation of unseen reach (straight down)\n",
    "- projection onto interpretable axes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## weird stuff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep the first latent coordinate over np.arange(-10, 10, 0.1) while the\n",
    "# other four coordinates are held at zero, recording element 0 of the squeezed\n",
    "# output of `vae.prior.transition` at each point.\n",
    "z_0_range = np.arange(-10, 10, 0.1)\n",
    "z_1_5 = torch.zeros(4, dtype=torch.float)\n",
    "f_z_0 = []\n",
    "for z_0 in z_0_range:\n",
    "    head = torch.tensor([z_0], dtype=torch.float)\n",
    "    z = torch.cat([head, z_1_5]).to(\"cuda:0\")\n",
    "    # assumes the transition takes a 4-D input with the latent on axis 1 -- TODO confirm\n",
    "    out = vae.prior.transition.forward(z=z[None, :, None, None], grad=False)\n",
    "    f_z_0.append(out.squeeze()[0].item())\n",
    "\n",
    "plt.plot(z_0_range, f_z_0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same sweep of the first latent coordinate, now for `vae_orth` and repeated\n",
    "# 20 times; each repeat draws the other four coordinates from 3 * N(0, 1) and\n",
    "# plots element 0 of the squeezed transition output as one curve.\n",
    "z_0_range = np.arange(-10, 10, 0.1)\n",
    "\n",
    "for draw_idx in range(20):\n",
    "    z_1_5 = torch.randn(4, dtype=torch.float) * 3\n",
    "    f_z_0 = []\n",
    "    for z_0 in z_0_range:\n",
    "        head = torch.tensor([z_0], dtype=torch.float)\n",
    "        z = torch.cat([head, z_1_5]).to(\"cuda:0\")\n",
    "        # assumes the transition takes a 4-D input with the latent on axis 1 -- TODO confirm\n",
    "        out = vae_orth.prior.transition.forward(z=z[None, :, None, None], grad=False)\n",
    "        f_z_0.append(out.squeeze()[0].item())\n",
    "\n",
    "    plt.plot(z_0_range, f_z_0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep the first latent coordinate for `vae_orth`, repeated 20 times with\n",
    "# the other four coordinates drawn from 3 * N(0, 1); here element 1 of the\n",
    "# squeezed transition output is plotted against the swept coordinate.\n",
    "z_0_range = np.arange(-10, 10, 0.1)\n",
    "\n",
    "for draw_idx in range(20):\n",
    "    z_1_5 = torch.randn(4, dtype=torch.float) * 3\n",
    "    f_z_1 = []\n",
    "    for z_0 in z_0_range:\n",
    "        head = torch.tensor([z_0], dtype=torch.float)\n",
    "        z = torch.cat([head, z_1_5]).to(\"cuda:0\")\n",
    "        # assumes the transition takes a 4-D input with the latent on axis 1 -- TODO confirm\n",
    "        out = vae_orth.prior.transition.forward(z=z[None, :, None, None], grad=False)\n",
    "        f_z_1.append(out.squeeze()[1].item())\n",
    "\n",
    "    plt.plot(z_0_range, f_z_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "virnn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
