{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## set up imports etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"4\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "sys.path.append('../')\n",
    "\n",
    "from rnn.vae import VAE\n",
    "from rnn.train import train_VAE\n",
    "from rnn.saving import save_model, load_model\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import h5py\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "from pathlib import Path\n",
    "\n",
    "mpl.rc_file(\"matplotlibrc\")\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## load model and dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_ROOT = Path(\"../../\").absolute() / \"data\" / \"processed\"\n",
    "RUN_ROOT = Path(\"../../\").absolute() / \"runs\"\n",
    "CHKPT_DIR = Path(\"../../\").absolute() / \"checkpoints\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chkpt = \"maze_pos_nc_5d/\"\n",
    "chkpt_path = CHKPT_DIR / chkpt\n",
    "\n",
    "model_save_name = sorted(chkpt_path.glob(\"*.pkl\"))[0].stem\n",
    "\n",
    "# load model\n",
    "suffixes = [\n",
    "    \"_state_dict_enc\", \n",
    "    \"_state_dict_prior\", \n",
    "    \"_task_params\", \n",
    "    \"_training_params\", \n",
    "    \"_vae_params\", \n",
    "]\n",
    "for suffix in suffixes:\n",
    "    model_save_name = model_save_name.replace(suffix, \"\")\n",
    "model_save_name = str(chkpt_path / model_save_name)\n",
    "vae, vae_params, task_params, training_params = load_model(model_save_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_dataset(\n",
    "    name: str,\n",
    "    phase: str = \"val\",\n",
    "    bin_size: int = 5,\n",
    "    t_forward: int = 0,\n",
    "    input_field: str = None,\n",
    "):\n",
    "    data_path = DATA_ROOT / name / phase\n",
    "    train_inputs = eval_inputs = None\n",
    "    normalization = None\n",
    "    with h5py.File(data_path / f\"train_input_{bin_size}ms.h5\", \"r\") as h5f:\n",
    "        if t_forward > 0:\n",
    "            train_data = np.concatenate([\n",
    "                np.concatenate([\n",
    "                    h5f['train_spikes_heldin'][()],\n",
    "                    h5f['train_spikes_heldin_forward'][()],\n",
    "                ], axis=1),\n",
    "                np.concatenate([\n",
    "                    h5f['train_spikes_heldout'][()],\n",
    "                    h5f['train_spikes_heldout_forward'][()],\n",
    "                ], axis=1),\n",
    "            ], axis=2)\n",
    "        else:\n",
    "            train_data = np.concatenate([h5f['train_spikes_heldin'][()], h5f['train_spikes_heldout'][()]], axis=2)\n",
    "        if input_field is not None:\n",
    "            train_inputs = h5f[f\"train_{input_field}\"][()]\n",
    "            if len(train_inputs.shape) == 1:\n",
    "                train_inputs = train_inputs[:, None]\n",
    "            if len(train_inputs.shape) == 2:\n",
    "                train_inputs = np.tile(train_inputs[:, None, :], reps=(1, train_data.shape[1], 1))\n",
    "            normalization = np.max(np.abs(train_inputs), axis=(0,1), keepdims=True)\n",
    "            train_inputs = train_inputs / normalization\n",
    "    with h5py.File(data_path / f\"eval_input_{bin_size}ms.h5\", \"r\") as h5f:\n",
    "        assert \"eval_spikes_heldout\" in h5f.keys()\n",
    "        eval_data = np.concatenate([h5f['eval_spikes_heldin'][()], h5f['eval_spikes_heldout'][()]], axis=2)\n",
    "        if input_field is not None:\n",
    "            eval_inputs = h5f[f\"eval_{input_field}\"][()]\n",
    "            if len(eval_inputs.shape) == 1:\n",
    "                eval_inputs = eval_inputs[:, None]\n",
    "            if len(eval_inputs.shape) == 2:\n",
    "                eval_inputs = np.tile(eval_inputs[:, None, :], reps=(1, eval_data.shape[1], 1))\n",
    "            assert normalization is not None\n",
    "            eval_inputs = eval_inputs / normalization\n",
    "    return (\n",
    "        torch.tensor(train_data, dtype=torch.float), \n",
    "        torch.tensor(eval_data, dtype=torch.float),\n",
    "        torch.tensor(train_inputs, dtype=torch.float) if train_inputs is not None else train_inputs,\n",
    "        torch.tensor(eval_inputs, dtype=torch.float) if eval_inputs is not None else eval_inputs,        \n",
    "    )\n",
    "\n",
    "\n",
    "class NLBDataset(Dataset):\n",
    "    def __init__(self, data, params, inputs=None):\n",
    "        self.data = data\n",
    "        self.data_eval = data # Not used in this setup\n",
    "        self.trial_dur = data.shape[0]\n",
    "        self.task_params = params\n",
    "        self.stim = (\n",
    "            inputs.float().to(self.data.device)\n",
    "            if inputs is not None else \n",
    "            torch.zeros(data.shape[0], data.shape[1], 0, device=self.data.device, dtype=torch.float)\n",
    "        )\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        return self.data[idx].T, self.stim[idx].T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data, eval_data, train_inputs, eval_inputs = load_dataset(\n",
    "    name=\"mc_maze_input\",\n",
    "    phase=\"pos\",\n",
    "    bin_size=20,\n",
    "    t_forward=0,\n",
    "    input_field=\"input\"\n",
    ")\n",
    "train_dataset = NLBDataset(train_data, dict(name=\"mc_maze_input\"), inputs=train_inputs)\n",
    "eval_dataset = NLBDataset(eval_data, dict(name=\"mc_maze_input\"), inputs=eval_inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## generate latents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dim_z = vae.vae_params[\"dim_z\"]\n",
    "n_inp = vae.vae_params[\"dim_u\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vae.to_device(\"cpu\")\n",
    "\n",
    "noise =0\n",
    "dur = 35\n",
    "t_on = 0\n",
    "t_of = 12\n",
    "z0=torch.randn(1,dim_z,1).cuda()*.1\n",
    "u = torch.zeros(1,n_inp,dur,device='cpu')\n",
    "\n",
    "u[0,1,t_on:t_of]=-1\n",
    "u[0,0,t_on:t_of]=1\n",
    "\n",
    "with torch.no_grad():\n",
    "    for _ in range(1):\n",
    "        z = vae.prior.get_latent_time_series(dur,z0=z0,u=u,noise_scale=noise).cpu().numpy()\n",
    "        plt.plot(z[0,:,:,0].T,color='black',alpha=.6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "def orthogonalise_network(vae):\n",
    "    with torch.no_grad():\n",
    "        m_or = vae.prior.transition.m #20,2\n",
    "        n_or = vae.prior.transition.n\n",
    "        J = m_or@n_or\n",
    "        u,s,v = torch.linalg.svd(J)\n",
    "        projection_matrix = u[:,:vae.dim_z].T@m_or\n",
    "        proj_chol = projection_matrix@torch.diag(vae.prior.std_embed_z(vae.prior.R_z))\n",
    "        m_new = u[:,:vae.dim_z]\n",
    "        n_new = (v[:vae.dim_z].T * s[:vae.dim_z]).T\n",
    "        vae.prior.transition.m.copy_(m_new)\n",
    "        vae.prior.transition.n.copy_(n_new)\n",
    "        vae.prior.chol_cov_embed = lambda x: torch.tril(x)\n",
    "        vae.prior.R_z=torch.nn.Parameter(torch.linalg.cholesky(proj_chol@proj_chol.T))\n",
    "        vae.prior.params['scalar_noise_z']=\"Cov\"\n",
    "    return vae\n",
    "\n",
    "vae_orth = copy.deepcopy(vae)\n",
    "vae_orth = orthogonalise_network(vae_orth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    m_or = vae.prior.transition.m #20,2\n",
    "    n_or = vae.prior.transition.n\n",
    "    J = m_or@n_or\n",
    "    u,s,v = torch.linalg.svd(J)\n",
    "    projection_matrix = u[:,:vae.dim_z].T@m_or\n",
    "projection_matrix = projection_matrix.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "noise =0\n",
    "dur = 35\n",
    "t_on = 0\n",
    "t_of = 12\n",
    "z0=torch.randn(1,dim_z,1).cuda()*.1\n",
    "u = torch.zeros(1,n_inp,dur,device='cpu')\n",
    "\n",
    "u[0,1,t_on:t_of]=-1\n",
    "u[0,0,t_on:t_of]=1\n",
    "\n",
    "with torch.no_grad():\n",
    "    for _ in range(1):\n",
    "        z_orth = vae_orth.prior.get_latent_time_series(dur,z0=z0,u=u,noise_scale=noise).cpu().numpy()\n",
    "        plt.plot(z_orth[0,:,:,0].T,color='black',alpha=.6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## infer NLB latents from data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vae.to_device('cuda')\n",
    "marginal_smoothing = False\n",
    "k = training_params['k']\n",
    "\n",
    "i = 0\n",
    "x = train_dataset.data[i].T.unsqueeze(0)\n",
    "u = train_dataset.stim[i].T.unsqueeze(0)\n",
    "t_held_in = x.shape[2]\n",
    "with torch.no_grad():\n",
    "    Qzs_filt, Qzs_sm, Xs_filt, Xs_sm= vae.predict_NLB(x.cuda(), u=u.cuda(),k=k,t_held_in = t_held_in, t_forward=0,marginal_smoothing=marginal_smoothing)\n",
    "plt.plot(Qzs_filt[0,:,:,0].cpu().T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_repeats = 10\n",
    "batch_size = 64\n",
    "k = 100\n",
    "device = \"cuda:0\"\n",
    "\n",
    "x_t = train_dataset.data.permute(0,2,1).to(device)\n",
    "u_t = train_dataset.stim.permute(0,2,1).to(device)\n",
    "t_held_in = x_t.shape[2]\n",
    "\n",
    "Qzs_filt_repeats = []\n",
    "Qzs_sm_repeats = []\n",
    "Xs_filt_repeats = []\n",
    "Xs_sm_repeats = []\n",
    "for i in range(n_repeats):\n",
    "    with torch.no_grad():\n",
    "        Qzs_filt, Qzs_sm, Xs_filt, Xs_sm = zip(*[\n",
    "            vae.predict_NLB(x_t_chunk, u=u_t_chunk, k=k, t_held_in = t_held_in, t_forward=0)\n",
    "            for x_t_chunk, u_t_chunk in zip(\n",
    "                torch.chunk(x_t, chunks=len(x_t) // batch_size, dim=0),\n",
    "                torch.chunk(u_t, chunks=len(u_t) // batch_size, dim=0),\n",
    "            )\n",
    "        ])\n",
    "        Qzs_filt_repeats.append(torch.cat(Qzs_filt, dim=0).cpu().numpy().mean(axis=-1).transpose(0, 2, 1))\n",
    "        Qzs_sm_repeats.append(torch.cat(Qzs_sm, dim=0).cpu().numpy().mean(axis=-1).transpose(0, 2, 1))\n",
    "        Xs_filt_repeats.append(torch.cat(Xs_filt, dim=0).cpu().numpy().mean(axis=-1).transpose(0, 2, 1))\n",
    "        Xs_sm_repeats.append(torch.cat(Xs_sm, dim=0).cpu().numpy().mean(axis=-1).transpose(0, 2, 1))\n",
    "\n",
    "Qzs_filt = np.mean(Qzs_filt_repeats, axis=0)\n",
    "Qzs_sm = np.mean(Qzs_sm_repeats, axis=0)\n",
    "Xs_filt = np.mean(Xs_filt_repeats, axis=0)\n",
    "Xs_sm = np.mean(Xs_sm_repeats, axis=0)\n",
    "\n",
    "del Qzs_filt_repeats, Qzs_sm_repeats, Xs_filt_repeats, Xs_sm_repeats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc; gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_pos = train_dataset.stim[:, 0, :].detach().cpu().numpy()\n",
    "angles = np.arctan2(target_pos[:, 1], target_pos[:, 0])\n",
    "angles = angles / (2 * np.pi) + 0.5\n",
    "angles = (np.round(angles * 8) % 8) / 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.unique(angles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(target_pos[:, 0], target_pos[:, 1], color=plt.cm.hsv(angles))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Qz_mean = np.empty((len(np.unique(angles)), *Qzs_filt.shape[1:]))\n",
    "for i, angle in enumerate(np.sort(np.unique(angles))):\n",
    "    mask = (angles == angle)\n",
    "    Qz_mean[i] = Qzs_filt[mask].mean(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=plt.figaspect(0.5))\n",
    "plt_start = 0\n",
    "plt_end=35\n",
    "fig,ax = plt.subplots(dim_z,dim_z)\n",
    "for angle_i in range(len(np.sort(np.unique(angles)))):\n",
    "    angle = np.sort(np.unique(angles))[angle_i]\n",
    "    for i in range(dim_z):\n",
    "        for j in range(i,dim_z):\n",
    "            if i == j:\n",
    "                ax[i,j].plot(Qz_mean[angle_i,plt_start:plt_end,i],alpha = .5, color=plt.cm.hsv(angle))\n",
    "            else:\n",
    "                ax[i,j].plot(\n",
    "                    Qz_mean[angle_i,plt_start:plt_end,i],\n",
    "                    Qz_mean[angle_i,plt_start:plt_end,j],\n",
    "                    alpha = .5, color=plt.cm.hsv(angle)\n",
    "                )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Qz_mean_orth = Qz_mean @ projection_matrix\n",
    "\n",
    "fig = plt.figure(figsize=plt.figaspect(0.5))\n",
    "plt_start = 5\n",
    "plt_end=13\n",
    "fig,ax = plt.subplots(dim_z,dim_z)\n",
    "for angle_i in range(len(np.sort(np.unique(angles)))):\n",
    "    angle = np.sort(np.unique(angles))[angle_i]\n",
    "    for i in range(dim_z):\n",
    "        for j in range(i,dim_z):\n",
    "            if i == j:\n",
    "                ax[i,j].plot(Qz_mean_orth[angle_i,plt_start:plt_end,i],alpha = .5, color=plt.cm.hsv(angle))\n",
    "            else:\n",
    "                ax[i,j].plot(\n",
    "                    Qz_mean_orth[angle_i,plt_start:plt_end,i],\n",
    "                    Qz_mean_orth[angle_i,plt_start:plt_end,j],\n",
    "                    alpha = .5, color=plt.cm.hsv(angle)\n",
    "                )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Qz_mean_orth = Qz_mean @ projection_matrix\n",
    "\n",
    "fig = plt.figure(figsize=plt.figaspect(0.5))\n",
    "plt_start = 0\n",
    "plt_end=35\n",
    "fig,ax = plt.subplots(dim_z,dim_z)\n",
    "for angle_i in range(len(np.sort(np.unique(angles)))):\n",
    "    angle = np.sort(np.unique(angles))[angle_i]\n",
    "    for i in range(dim_z):\n",
    "        for j in range(i,dim_z):\n",
    "            if i == j:\n",
    "                ax[i,j].plot(Qz_mean_orth[angle_i,plt_start:plt_end,i],alpha = .5, color=plt.cm.hsv(angle))\n",
    "            else:\n",
    "                ax[i,j].plot(\n",
    "                    Qz_mean_orth[angle_i,plt_start:plt_end,i],\n",
    "                    Qz_mean_orth[angle_i,plt_start:plt_end,j],\n",
    "                    alpha = .5, color=plt.cm.hsv(angle)\n",
    "                )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Qz_mean_orth = Qz_mean @ projection_matrix\n",
    "\n",
    "fig = plt.figure(figsize=plt.figaspect(0.5))\n",
    "plt_start = 12\n",
    "plt_end=35\n",
    "fig,ax = plt.subplots(dim_z,dim_z)\n",
    "for angle_i in range(len(np.sort(np.unique(angles)))):\n",
    "    angle = np.sort(np.unique(angles))[angle_i]\n",
    "    for i in range(dim_z):\n",
    "        for j in range(i,dim_z):\n",
    "            if i == j:\n",
    "                ax[i,j].plot(Qz_mean_orth[angle_i,plt_start:plt_end,i],alpha = .5, color=plt.cm.hsv(angle))\n",
    "            else:\n",
    "                ax[i,j].plot(\n",
    "                    Qz_mean_orth[angle_i,plt_start:plt_end,i],\n",
    "                    Qz_mean_orth[angle_i,plt_start:plt_end,j],\n",
    "                    alpha = .5, color=plt.cm.hsv(angle)\n",
    "                )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# generate trajectories wth corresponding angles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Qzs_filt.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def project_to_square(theta):\n",
    "    if theta > 180:\n",
    "        theta -= 360\n",
    "    elif theta < -180:\n",
    "        theta += 360\n",
    "    if (theta >= 45 and theta < 135):\n",
    "        y = 1.0\n",
    "        x = y * np.cos(theta / 180 * np.pi) / np.sin(theta / 180 * np.pi)\n",
    "    elif (theta >= 135 or theta < -135):\n",
    "        x = -1.0\n",
    "        y = x * np.sin(theta / 180 * np.pi) / np.cos(theta / 180 * np.pi)\n",
    "    elif (theta >= -135 and theta < -45):\n",
    "        y = -1.0\n",
    "        x = y * np.cos(theta / 180 * np.pi) / np.sin(theta / 180 * np.pi)\n",
    "    else:\n",
    "        x = 1.0\n",
    "        y = x * np.sin(theta / 180 * np.pi) / np.cos(theta / 180 * np.pi)\n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot 3 of the latents\n",
    "offset_start = 0\n",
    "noise_scale = 1\n",
    "# targets = [\n",
    "#     [-1,0],\n",
    "#     [-1,-1],\n",
    "#     [0,-1],\n",
    "#     [1,-1],\n",
    "#     [1,0],\n",
    "#     [1,1],\n",
    "#     [0,1],\n",
    "#     [-1,1],\n",
    "# ]\n",
    "gen_angles = np.arange(-180, 180, 7.5)\n",
    "targets = [project_to_square(ang) for ang in gen_angles]\n",
    "dur = 35 - offset_start\n",
    "t_on = 0\n",
    "t_of = 12 - offset_start\n",
    "R_z=0.05\n",
    "prop_cycle =[plt.cm.hsv(i) for i in np.arange(0, 1, 1/len(targets))]\n",
    "initial_state = Qzs_filt[:, offset_start, :].mean(axis=0)[None, :, None, None]\n",
    "z0 = torch.tile(torch.tensor(initial_state, dtype=torch.float), (len(targets), 1, 1, 1))\n",
    "\n",
    "n_repeats = 50\n",
    "z_all =np.zeros((n_repeats,len(targets),dur,dim_z))\n",
    "for ri in range(n_repeats):\n",
    "    input = np.zeros((len(targets), dur,n_inp))\n",
    "    z=np.zeros((len(targets), dur,dim_z))\n",
    "    for i, target in enumerate(targets):\n",
    "        input[i, t_on:t_of,:]=np.array(target)[None, :]\n",
    "    input = torch.tensor(input, dtype=torch.float).permute(0, 2, 1).cuda()\n",
    "    z = vae.prior.get_latent_time_series(dur,u=input,z0=z0,noise_scale=noise_scale).cpu().numpy()\n",
    "    z_all[ri]=z.transpose(0, 2, 1, 3).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(*np.array(targets).T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot 3 of the latents\n",
    "# noise_scale = 1\n",
    "# targets = [\n",
    "#     [-1,0],\n",
    "#     [-1,-1],\n",
    "#     [0,-1],\n",
    "#     [1,-1],\n",
    "#     [1,0],\n",
    "#     [1,1],\n",
    "#     [0,1],\n",
    "#     [-1,1],\n",
    "# ]\n",
    "# dur = 100\n",
    "# t_on=0\n",
    "# t_of= 100\n",
    "# R_z=0.05\n",
    "# prop_cycle =[plt.cm.hsv(i) for i in np.arange(0, 1, 1/len(targets))]\n",
    "\n",
    "# n_repeats = 10\n",
    "# z_all_long =np.zeros((n_repeats,len(targets),dur,dim_z))\n",
    "# for ri in range(n_repeats):\n",
    "#     input = np.zeros((len(targets), dur,n_inp))\n",
    "#     z=np.zeros((len(targets), dur,dim_z))\n",
    "#     for i, target in enumerate(targets):\n",
    "#         input[i, t_on:t_of,:]=np.array(target)[None, :]\n",
    "#     input = torch.tensor(input, dtype=torch.float).permute(0, 2, 1).cuda()\n",
    "#     z0 = Qzs_filt[np.random.permutation(Qzs_filt.shape[0])[:input.shape[0]], 0, :][:, :, None, None]\n",
    "#     z0 = torch.tensor(z0, dtype=torch.float)\n",
    "#     z = vae.prior.get_latent_time_series(dur,u=input,z0=z0,noise_scale=noise_scale).cpu().numpy()\n",
    "#     z_all_long[ri]=z.transpose(0, 2, 1, 3).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_all.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_mean = np.mean(z_all,axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=plt.figaspect(0.5))\n",
    "plt_start = 0\n",
    "plt_end=35\n",
    "fig,ax = plt.subplots(dim_z,dim_z)\n",
    "for i in range(dim_z):\n",
    "    for j in range(i,dim_z):\n",
    "        ax[i,j].set_prop_cycle('color',prop_cycle)\n",
    "for targ_i in range(len(targets)):\n",
    "    for i in range(dim_z):\n",
    "        for j in range(i,dim_z):\n",
    "            if i == j:\n",
    "                ax[i,j].plot(z_mean[targ_i,plt_start:plt_end,i],alpha = .5)\n",
    "            else:\n",
    "                ax[i,j].plot(z_mean[targ_i,plt_start:plt_end,i],z_mean[targ_i,plt_start:plt_end,j],alpha = .5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_orth_mean = z_mean @ projection_matrix\n",
    "\n",
    "fig = plt.figure(figsize=plt.figaspect(0.5))\n",
    "plt_start = 0\n",
    "plt_end=35\n",
    "fig,ax = plt.subplots(dim_z,dim_z)\n",
    "for i in range(dim_z):\n",
    "    for j in range(i,dim_z):\n",
    "        ax[i,j].set_prop_cycle('color',prop_cycle)\n",
    "for targ_i in range(len(targets)):\n",
    "    for i in range(dim_z):\n",
    "        for j in range(i,dim_z):\n",
    "            if i == j:\n",
    "                ax[i,j].plot(z_orth_mean[targ_i,plt_start:plt_end,i],alpha = .5)\n",
    "            else:\n",
    "                ax[i,j].plot(z_orth_mean[targ_i,plt_start:plt_end,i],z_orth_mean[targ_i,plt_start:plt_end,j],alpha = .5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=plt.figaspect(0.5))\n",
    "plt_start = 0\n",
    "plt_end=12\n",
    "fig,ax = plt.subplots(dim_z,dim_z)\n",
    "for i in range(dim_z):\n",
    "    for j in range(i,dim_z):\n",
    "        ax[i,j].set_prop_cycle('color',prop_cycle)\n",
    "for targ_i in range(len(targets)):\n",
    "    for i in range(dim_z):\n",
    "        for j in range(i,dim_z):\n",
    "            if i == j:\n",
    "                ax[i,j].plot(z_orth_mean[targ_i,plt_start:plt_end,i],alpha = .5)\n",
    "            else:\n",
    "                ax[i,j].plot(z_orth_mean[targ_i,plt_start:plt_end,i],z_orth_mean[targ_i,plt_start:plt_end,j],alpha = .5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=plt.figaspect(0.5))\n",
    "plt_start = 12\n",
    "plt_end= 35\n",
    "fig,ax = plt.subplots(dim_z,dim_z)\n",
    "for i in range(dim_z):\n",
    "    for j in range(i,dim_z):\n",
    "        ax[i,j].set_prop_cycle('color',prop_cycle)\n",
    "for targ_i in range(len(targets)):\n",
    "    for i in range(dim_z):\n",
    "        for j in range(i,dim_z):\n",
    "            if i == j:\n",
    "                ax[i,j].plot(z_orth_mean[targ_i,plt_start:plt_end,i],alpha = .5)\n",
    "            else:\n",
    "                ax[i,j].plot(z_orth_mean[targ_i,plt_start:plt_end,i],z_orth_mean[targ_i,plt_start:plt_end,j],alpha = .5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_orth_all = z_all @ projection_matrix\n",
    "initial_state_proj = initial_state.squeeze() @ projection_matrix\n",
    "plt_start = 2\n",
    "plt_end = 12\n",
    "plt_dim1 = 2\n",
    "plt_dim2 = 0\n",
    "fig, axs = plt.subplots(1, 1, figsize=(1.0, 1.0))\n",
    "for targ_i in range(len(targets)):\n",
    "    axs.plot(\n",
    "        # np.concatenate([initial_state_proj[plt_dim1:plt_dim1+1], z_orth_mean[targ_i,plt_start:plt_end,plt_dim1]]),\n",
    "        # np.concatenate([initial_state_proj[plt_dim2:plt_dim2+1], z_orth_mean[targ_i,plt_start:plt_end,plt_dim2]]), \n",
    "        z_orth_mean[targ_i,plt_start:plt_end,plt_dim1],\n",
    "        z_orth_mean[targ_i,plt_start:plt_end,plt_dim2],\n",
    "        color=prop_cycle[targ_i], alpha=0.6)\n",
    "    \n",
    "    for trial in range(z_orth_all.shape[0]):\n",
    "        axs.plot(\n",
    "            # np.concatenate([initial_state_proj[plt_dim1:plt_dim1+1], z_orth_mean[targ_i,plt_start:plt_end,plt_dim1]]),\n",
    "            # np.concatenate([initial_state_proj[plt_dim2:plt_dim2+1], z_orth_mean[targ_i,plt_start:plt_end,plt_dim2]]), \n",
    "            z_orth_all[trial,targ_i,plt_start:plt_end,plt_dim1],\n",
    "            z_orth_all[trial,targ_i,plt_start:plt_end,plt_dim2],\n",
    "            color=prop_cycle[targ_i], alpha=0.1, linewidth=0.2)\n",
    "axs.scatter(z_orth_mean[:,plt_end-1,plt_dim1],z_orth_mean[:,plt_end-1,plt_dim2], \n",
    "            s=20, color=prop_cycle, edgecolor=\"black\", linewidth=0.8, zorder=2)\n",
    "axs.set_xticks([])\n",
    "axs.set_yticks([])\n",
    "axs.set_xlabel(f\"$z_{plt_dim1}$\")\n",
    "axs.set_ylabel(f\"$z_{plt_dim2}$\")\n",
    "\n",
    "xlim = axs.get_xlim()\n",
    "xrange = xlim[1] - xlim[0]\n",
    "axs.set_xlim(xlim[0] - 0.04 * xrange, xlim[1] + 0.04 * xrange)\n",
    "\n",
    "ylim = axs.get_ylim()\n",
    "yrange = ylim[1] - ylim[0]\n",
    "# axs.set_ylim(ylim[0] - 0.04 * yrange, ylim[1] + 0.04 * yrange)\n",
    "\n",
    "axs.set_title(\"Preparatory phase\")\n",
    "plt.savefig(\"../../plots/prep_latents.svg\")\n",
    "plt.savefig(\"../../plots/prep_latents.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 1, figsize=(1.2, 1.2), subplot_kw={'projection': '3d'})\n",
    "\n",
    "z_orth_all = z_all @ projection_matrix\n",
    "initial_state_proj = initial_state.squeeze() @ projection_matrix\n",
    "plt_start = 12\n",
    "plt_end = 35\n",
    "plt_dim1 = 2\n",
    "plt_dim2 = 3\n",
    "plt_dim3 = 4\n",
    "\n",
    "for targ_i in range(len(targets)):\n",
    "    axs.plot(\n",
    "        z_orth_mean[targ_i,plt_start:plt_end,plt_dim1],\n",
    "        z_orth_mean[targ_i,plt_start:plt_end,plt_dim2],\n",
    "        z_orth_mean[targ_i,plt_start:plt_end,plt_dim3],\n",
    "        color=prop_cycle[targ_i], alpha=0.6)\n",
    "        \n",
    "    for trial in range(z_orth_all.shape[0]):\n",
    "        axs.plot(\n",
    "            z_orth_all[trial,targ_i,plt_start:plt_end,plt_dim1],\n",
    "            z_orth_all[trial,targ_i,plt_start:plt_end,plt_dim2],\n",
    "            z_orth_all[trial,targ_i,plt_start:plt_end,plt_dim3],\n",
    "            color=prop_cycle[targ_i], alpha=0.1, linewidth=0.2)\n",
    "        \n",
    "axs.scatter(z_orth_mean[:,plt_start,plt_dim1],z_orth_mean[:,plt_start,plt_dim2], z_orth_mean[:,plt_start,plt_dim3],\n",
    "            s=15, color=prop_cycle, edgecolor=\"black\", linewidth=0.8, zorder=2, alpha=0.8)\n",
    "axs.scatter(z_orth_mean[:,plt_end-1,plt_dim1],z_orth_mean[:,plt_end-1,plt_dim2], z_orth_mean[:,plt_end-1,plt_dim3],\n",
    "            s=15, marker=\"^\", color=prop_cycle, edgecolor=\"black\", linewidth=0.8, zorder=2, alpha=0.8)\n",
    "axs.set_xticks([])\n",
    "axs.set_yticks([])\n",
    "axs.set_zticks([])\n",
    "axs.set_xlabel(f\"$z_{plt_dim1}$\")\n",
    "axs.set_ylabel(f\"$z_{plt_dim2}$\")\n",
    "axs.set_zlabel(f\"$z_{plt_dim3}$\")\n",
    "axs.xaxis.labelpad=-15\n",
    "axs.yaxis.labelpad=-15\n",
    "axs.zaxis.labelpad=-16\n",
    "\n",
    "axs.set_title(\"Movement phase\", pad=-10)\n",
    "plt.savefig(\"../../plots/movement_latents.svg\")\n",
    "plt.savefig(\"../../plots/movement_latents.pdf\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.lines import Line2D\n",
    "from matplotlib.patches import Patch\n",
    "from matplotlib.markers import MarkerStyle\n",
    "\n",
    "fig = plt.figure(figsize=(1.0, 1.0))\n",
    "legend_elements = [Line2D([], [], color='gray', marker='.', linestyle='None',\n",
    "                          markeredgecolor=\"black\", markeredgewidth=1.6, \n",
    "                          markersize=12, label='Go cue'),\n",
    "                   Line2D([], [], color='gray', marker='^', linestyle='None',\n",
    "                          markeredgecolor=\"black\", markeredgewidth=1.6, \n",
    "                          markersize=7, label='Reach end')]\n",
    "plt.legend(handles=legend_elements)\n",
    "plt.axis('off')\n",
    "plt.savefig(\"../../plots/latents_legend.svg\")\n",
    "plt.savefig(\"../../plots/latents_legend.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## decoded reaches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = DATA_ROOT / \"mc_maze_input\" / \"pos\" / \"eval_target_20ms.h5\"\n",
    "with h5py.File(data_path, \"r\") as h5f:\n",
    "    train_behavior = h5f[\"mc_maze_20\"][\"train_behavior\"][()]\n",
    "    eval_behavior = h5f[\"mc_maze_20\"][\"eval_behavior\"][()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import Ridge\n",
    "\n",
    "flatten2d = lambda arr: arr.reshape(-1, arr.shape[-1])\n",
    "rate_decoder = Ridge(alpha=1e-6)\n",
    "\n",
    "rate_decoder.fit(\n",
    "    flatten2d(Xs_filt),\n",
    "    flatten2d(train_behavior),\n",
    ")\n",
    "\n",
    "print(rate_decoder.score(flatten2d(Xs_filt), flatten2d(train_behavior)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "latent_decoder = Ridge(alpha=1e-6)\n",
    "\n",
    "latent_decoder.fit(\n",
    "    flatten2d(Qzs_filt),\n",
    "    flatten2d(train_behavior),\n",
    ")\n",
    "\n",
    "print(latent_decoder.score(flatten2d(Qzs_filt), flatten2d(train_behavior)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode every simulated latent trajectory in `z_all` to firing rates.\n",
    "# x_all mirrors z_all's leading (trial, target, time) axes; the neuron axis is\n",
    "# sized from the decoder output instead of a hard-coded 182.\n",
    "device = \"cuda:0\"\n",
    "\n",
    "x_all = None\n",
    "with torch.no_grad():\n",
    "    for i in range(z_all.shape[0]):\n",
    "        # (target, time, dim_z) -> (target, dim_z, time, 1) before decoding\n",
    "        z_t = torch.tensor(z_all[i], dtype=torch.float, device=device).unsqueeze(-1).permute(0, 2, 1, 3)\n",
    "        x = vae.obs_rectify(vae.prior.get_observation(z=z_t)).cpu().numpy()\n",
    "        x_trial = x.transpose(0, 2, 1, 3).squeeze()\n",
    "        if x_all is None:\n",
    "            x_all = np.empty((*z_all.shape[:-1], x_trial.shape[-1]))\n",
    "        x_all[i] = x_trial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: expected (trials, targets, time, neurons).\n",
    "x_all.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read out behavior from the decoded rates, then restore x_all's leading axes.\n",
    "behavior_all = rate_decoder.predict(flatten2d(x_all)).reshape(*x_all.shape[:-1], -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cumulative sum over time (axis 2) turns the decoded behavior into a trajectory.\n",
    "# NOTE(review): this treats the behavior as per-bin velocity; / 50 presumably\n",
    "# converts with 20 ms bins (50 bins per second) - confirm units.\n",
    "position_all = np.cumsum(behavior_all, axis=2) / 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decoded 2-D reach trajectories: the trial-averaged path for each target (bold)\n",
    "# and every individual trial (faint), colored per target via `prop_cycle`.\n",
    "# NOTE: `targets`/`prop_cycle` come from an earlier cell and are redefined in the\n",
    "# interpolation section further down.\n",
    "position_mean = position_all.mean(axis=0)\n",
    "fig, axs = plt.subplots(1, 1, figsize=(1.0, 1.0))\n",
    "for targ_i in range(len(targets)):\n",
    "    axs.plot(\n",
    "        position_mean[targ_i,:,0],\n",
    "        position_mean[targ_i,:,1],\n",
    "        color=prop_cycle[targ_i], alpha=0.6,\n",
    "        # linestyle=(\"--\" if targ_i == 2 else \"-\"),\n",
    "    )\n",
    "        \n",
    "    # individual trials for the same target\n",
    "    for trial in range(position_all.shape[0]):\n",
    "        axs.plot(\n",
    "            position_all[trial,targ_i,:,0],\n",
    "            position_all[trial,targ_i,:,1],\n",
    "            color=prop_cycle[targ_i], alpha=0.2, linewidth=0.4)\n",
    "# Disabled end-point markers: squares for conditions 0-1 and 3+, a star for condition 2\n",
    "# (matching the seen/unseen legend cell).\n",
    "# axs.scatter(position_mean[:2,-1,0], position_mean[:2,-1,1], color=prop_cycle[:2], s=20, marker=\"s\", edgecolor=\"black\", linewidth=0.8, zorder=2)\n",
    "# axs.scatter(position_mean[2,-1,0], position_mean[2,-1,1], color=prop_cycle[2], marker=\"*\", s=28, edgecolor=\"black\", linewidth=0.8, zorder=2)\n",
    "# axs.scatter(position_mean[3:,-1,0], position_mean[3:,-1,1], color=prop_cycle[3:], s=20, marker=\"s\", edgecolor=\"black\", linewidth=0.8, zorder=2)\n",
    "axs.set_xticks([])\n",
    "axs.set_yticks([])\n",
    "axs.set_xlabel(\"x position\")\n",
    "axs.set_ylabel(\"y position\")\n",
    "axs.set_title(\"decoded reaches\")\n",
    "\n",
    "# xlim = axs.get_xlim()\n",
    "# xrange = xlim[1] - xlim[0]\n",
    "# axs.set_xlim(xlim[0] + 0.1 * xrange, xlim[1] - 0.1 * xrange)\n",
    "# axs.set_xlim(-150, 150)\n",
    "\n",
    "# xlim = axs.get_xlim()\n",
    "# xrange = xlim[1] - xlim[0]\n",
    "# axs.set_xlim(xlim[0] - 0.1 * xrange, xlim[1] + 0.1 * xrange)\n",
    "# axs.set_ylim(-100, 100)\n",
    "\n",
    "plt.savefig(\"../../plots/decoded_reaches.svg\")\n",
    "plt.savefig(\"../../plots/decoded_reaches.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stand-alone legend for the reach plots (square = seen condition, star = unseen\n",
    "# condition), saved as its own file for figure assembly.\n",
    "from matplotlib.lines import Line2D\n",
    "from matplotlib.patches import Patch\n",
    "from matplotlib.markers import MarkerStyle\n",
    "\n",
    "fig = plt.figure(figsize=(1.0, 1.0))\n",
    "legend_elements = [Line2D([], [], color='gray', marker='s', linestyle='None',\n",
    "                          markeredgecolor=\"black\", markeredgewidth=1.6, \n",
    "                          markersize=7, label='Seen condition'),\n",
    "                   Line2D([], [], color='gray', marker=\"*\", linestyle='None',\n",
    "                          markeredgecolor=\"black\", markeredgewidth=1.0, \n",
    "                          markersize=9, label='Unseen condition')]\n",
    "plt.legend(handles=legend_elements)\n",
    "plt.axis('off')\n",
    "plt.savefig(\"../../plots/reaches_legend.svg\")\n",
    "plt.savefig(\"../../plots/reaches_legend.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check of the simulated latent array's shape (leading axes match x_all).\n",
    "z_all.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## spike stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulate the prior once per training trial: each run is driven by that trial's\n",
    "# stimulus and starts from that trial's first latent state in `Qzs_filt`.\n",
    "\n",
    "device = \"cuda:0\"\n",
    "u = train_dataset.stim.permute(0,2,1).to(device)\n",
    "initial_states = Qzs_filt[:, 0, :]\n",
    "# Across-trial statistics, only used by the commented-out random z0 below.\n",
    "initial_states_mean = initial_states.mean(axis=0)\n",
    "initial_states_std = initial_states.std(axis=0)\n",
    "\n",
    "z_sim = np.zeros((u.shape[0], dur, dim_z))\n",
    "with torch.no_grad():  # inference only, as in the decoding cells\n",
    "    for ri in range(u.shape[0]):\n",
    "        input = u[ri][None, :, :]\n",
    "        # z0 = np.random.randn(dim_z) * initial_states_std + initial_states_mean\n",
    "        z0 = torch.tensor(initial_states[ri], dtype=torch.float, device=device)[None, :, None, None]\n",
    "        # z0 = torch.tensor(z0, dtype=torch.float, device=device)[None, :, None, None]\n",
    "        z = vae.prior.get_latent_time_series(dur, u=input, z0=z0, noise_scale=noise_scale).cpu().numpy()\n",
    "        z_sim[ri] = z.transpose(0, 2, 1, 3).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode each simulated trial's latents to firing rates; x_sim is (trials, time, neurons).\n",
    "# The neuron axis is sized from the decoder output instead of a hard-coded 182.\n",
    "x_sim = None\n",
    "with torch.no_grad():\n",
    "    for i in range(z_sim.shape[0]):\n",
    "        # (time, dim_z) -> (1, dim_z, time, 1) before decoding\n",
    "        z_t = torch.tensor(z_sim[i], dtype=torch.float, device=device).unsqueeze(-1).unsqueeze(0).permute(0, 2, 1, 3)\n",
    "        x = vae.obs_rectify(vae.prior.get_observation(z=z_t)).cpu().numpy()\n",
    "        x_trial = x.transpose(0, 2, 1, 3).squeeze()\n",
    "        if x_sim is None:\n",
    "            x_sim = np.empty((*z_sim.shape[:-1], x_trial.shape[-1]))\n",
    "        x_sim[i] = x_trial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def coefficient_of_variation(data):\n",
    "    \"\"\"Coefficient of variation (std / mean) along axis 0.\n",
    "\n",
    "    Entries whose mean is exactly 0 get a CV of 0 instead of inf/nan. The\n",
    "    division is only evaluated where the mean is nonzero, so no\n",
    "    divide-by-zero warning is emitted.\n",
    "    \"\"\"\n",
    "    means = np.mean(data, axis=0)\n",
    "    std_devs = np.std(data, axis=0)\n",
    "    return np.divide(std_devs, means, out=np.zeros_like(std_devs, dtype=float), where=(means != 0))\n",
    "\n",
    "def binned_spikes_to_times(binned_spikes, bin_size=0.02):\n",
    "    \"\"\"Convert one neuron's binned spike counts into spike times.\n",
    "\n",
    "    A bin `i` holding `c` spikes contributes the times `(i + k / c) * bin_size`\n",
    "    for `k = 0, ..., c - 1`, i.e. its spikes are spread evenly inside the bin.\n",
    "\n",
    "    Args:\n",
    "        binned_spikes: 1-D array of non-negative integer spike counts.\n",
    "        bin_size: bin width; the returned times are in the same unit.\n",
    "\n",
    "    Returns:\n",
    "        1-D float array of spike times in ascending order.\n",
    "    \"\"\"\n",
    "    binned_spikes = np.asarray(binned_spikes)\n",
    "    idxs = np.nonzero(binned_spikes)[0]\n",
    "    counts = binned_spikes[idxs].astype(int)\n",
    "    spike_bins = np.repeat(idxs, counts)\n",
    "    # Rank of every spike within its own bin (0, 1, ..., c - 1), computed in one pass\n",
    "    # instead of re-scanning the whole spike array once per multi-spike bin.\n",
    "    first_in_bin = np.repeat(np.cumsum(counts) - counts, counts)\n",
    "    rank_in_bin = np.arange(spike_bins.size) - first_in_bin\n",
    "    spike_times = spike_bins + rank_in_bin / np.repeat(counts, counts)\n",
    "    return spike_times * bin_size\n",
    "\n",
    "def compute_isi_stats(spikes, bin_size=0.02):\n",
    "    \"\"\"Per-neuron inter-spike-interval (ISI) statistics pooled over trials.\n",
    "\n",
    "    Args:\n",
    "        spikes: array indexed as (trial, time bin, neuron) holding spike counts.\n",
    "        bin_size: bin width passed to `binned_spikes_to_times`.\n",
    "\n",
    "    Returns:\n",
    "        `(isi_means, isi_stds, isi_cvs)`, each a 1-D array with one entry per\n",
    "        neuron. Neurons without any ISI (fewer than two spikes in every trial)\n",
    "        get nan, without numpy's empty-slice warnings.\n",
    "    \"\"\"\n",
    "    n_trials, _, n_neurons = spikes.shape\n",
    "    isi_means = np.full(n_neurons, np.nan)\n",
    "    isi_stds = np.full(n_neurons, np.nan)\n",
    "    isi_cvs = np.full(n_neurons, np.nan)\n",
    "    for neuron in range(n_neurons):\n",
    "        # ISIs are taken within each trial (never across trial boundaries), then pooled.\n",
    "        per_trial = [np.diff(binned_spikes_to_times(spikes[i, :, neuron], bin_size=bin_size)) for i in range(n_trials)]\n",
    "        isis = np.concatenate(per_trial) if per_trial else np.empty(0)\n",
    "        if isis.size == 0:\n",
    "            continue\n",
    "        isi_means[neuron] = np.mean(isis)\n",
    "        isi_stds[neuron] = np.std(isis)\n",
    "        isi_cvs[neuron] = coefficient_of_variation(isis)\n",
    "    return isi_means, isi_stds, isi_cvs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overall (condition-agnostic) statistics of the recorded training spikes,\n",
    "# pooled over trials and time bins for every neuron.\n",
    "spikes = train_dataset.data.detach().cpu().numpy()\n",
    "\n",
    "sr_mean_all = spikes.mean(axis=(0,1))\n",
    "sr_std_all = spikes.std(axis=(0,1))\n",
    "isi_mean_all, isi_std_all, isi_cv_all = compute_isi_stats(spikes)\n",
    "\n",
    "panels = [\n",
    "    (sr_mean_all, \"mean spike rate\"),\n",
    "    (sr_std_all, \"std spike rate\"),\n",
    "    (isi_mean_all, \"mean isi\"),\n",
    "    (isi_std_all, \"std isi\"),\n",
    "    (isi_cv_all, \"isi cv\"),\n",
    "]\n",
    "fig, axs = plt.subplots(1, len(panels), figsize=(5,1))\n",
    "for panel_ax, (values, title) in zip(axs, panels):\n",
    "    panel_ax.hist(values)\n",
    "    panel_ax.set_title(title)\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overall statistics of the held-out (eval) recorded spikes, pooled over trials\n",
    "# and time bins for every neuron.\n",
    "test_spikes = eval_dataset.data.detach().cpu().numpy()\n",
    "\n",
    "test_sr_mean_all = test_spikes.mean(axis=(0,1))\n",
    "test_sr_std_all = test_spikes.std(axis=(0,1))\n",
    "test_isi_mean_all, test_isi_std_all, test_isi_cv_all = compute_isi_stats(test_spikes)\n",
    "\n",
    "panels = [\n",
    "    (test_sr_mean_all, \"mean spike rate\"),\n",
    "    (test_sr_std_all, \"std spike rate\"),\n",
    "    (test_isi_mean_all, \"mean isi\"),\n",
    "    (test_isi_std_all, \"std isi\"),\n",
    "    (test_isi_cv_all, \"isi cv\"),\n",
    "]\n",
    "fig, axs = plt.subplots(1, len(panels), figsize=(5,1))\n",
    "for panel_ax, (values, title) in zip(axs, panels):\n",
    "    panel_ax.hist(values)\n",
    "    panel_ax.set_title(title)\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spike counts sampled from the model: Poisson draws around the rates `x_sim`\n",
    "# simulated from the prior. A seeded generator makes the comparisons reproducible.\n",
    "POISSON_SEED = 0\n",
    "rng = np.random.default_rng(POISSON_SEED)\n",
    "inferred_spikes = rng.poisson(x_sim)\n",
    "\n",
    "inf_sr_mean_all = inferred_spikes.mean(axis=(0,1))\n",
    "inf_sr_std_all = inferred_spikes.std(axis=(0,1))\n",
    "inf_isi_mean_all, inf_isi_std_all, inf_isi_cv_all = compute_isi_stats(inferred_spikes)\n",
    "\n",
    "fig, axs = plt.subplots(1, 5, figsize=(5,1))\n",
    "axs[0].hist(inf_sr_mean_all)\n",
    "axs[0].set_title(\"mean spike rate\")\n",
    "\n",
    "axs[1].hist(inf_sr_std_all)\n",
    "axs[1].set_title(\"std spike rate\")\n",
    "\n",
    "axs[2].hist(inf_isi_mean_all)\n",
    "axs[2].set_title(\"mean isi\")\n",
    "\n",
    "axs[3].hist(inf_isi_std_all)\n",
    "axs[3].set_title(\"std isi\")\n",
    "\n",
    "axs[4].hist(inf_isi_cv_all)\n",
    "axs[4].set_title(\"isi cv\")\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-neuron statistics of the model-sampled spikes (x) against the recorded\n",
    "# training spikes (y); points on the dashed identity line match exactly.\n",
    "true_vals = [sr_mean_all, sr_std_all, isi_mean_all, isi_std_all, isi_cv_all]\n",
    "inf_vals = [inf_sr_mean_all, inf_sr_std_all, inf_isi_mean_all, inf_isi_std_all, inf_isi_cv_all]\n",
    "labels = [\"mean spike rate\", \"std spike rate\", \"mean isi\", \"std isi\", \"isi cv\"]\n",
    "\n",
    "fig, axs = plt.subplots(1, 5, figsize=(7.5,1.8))\n",
    "for i, (true, inf, label) in enumerate(zip(true_vals, inf_vals, labels)):\n",
    "    axs[i].scatter(inf, true)\n",
    "    xlim = axs[i].get_xlim()\n",
    "    ylim = axs[i].get_ylim()\n",
    "    lim = max(xlim[1], ylim[1])\n",
    "    # Draw the identity over the full square range rather than a fixed [-0.5, 1.5].\n",
    "    axs[i].plot([0, lim], [0, lim], linestyle='--', color='black')\n",
    "    axs[i].set_xlim(0, lim)\n",
    "    axs[i].set_ylim(0, lim)\n",
    "    axs[i].set_title(label)\n",
    "    axs[i].set_xlabel(\"inferred\")\n",
    "axs[0].set_ylabel(\"true\")\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: per-neuron statistics of the held-out eval spikes (x) against the\n",
    "# training spikes (y), i.e. how closely two recorded splits agree.\n",
    "true_vals = [sr_mean_all, sr_std_all, isi_mean_all, isi_std_all, isi_cv_all]\n",
    "inf_vals = [test_sr_mean_all, test_sr_std_all, test_isi_mean_all, test_isi_std_all, test_isi_cv_all]\n",
    "labels = [\"mean spike rate\", \"std spike rate\", \"mean isi\", \"std isi\", \"isi cv\"]\n",
    "\n",
    "fig, axs = plt.subplots(1, 5, figsize=(7.5,1.8))\n",
    "for i, (true, inf, label) in enumerate(zip(true_vals, inf_vals, labels)):\n",
    "    axs[i].scatter(inf, true)\n",
    "    xlim = axs[i].get_xlim()\n",
    "    ylim = axs[i].get_ylim()\n",
    "    lim = max(xlim[1], ylim[1])\n",
    "    # Draw the identity over the full square range rather than a fixed [-0.5, 1.5].\n",
    "    axs[i].plot([0, lim], [0, lim], linestyle='--', color='black')\n",
    "    axs[i].set_xlim(0, lim)\n",
    "    axs[i].set_ylim(0, lim)\n",
    "    axs[i].set_title(label)\n",
    "    axs[i].set_xlabel(\"test\")\n",
    "axs[0].set_ylabel(\"train\")\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-condition statistics of the recorded training spikes. Conditions come from the\n",
    "# stimulus at the first time step (taken as the target position): its arctan2 angle\n",
    "# in [-pi, pi] is mapped to [0, 1] and rounded to one of 8 bins k/8, 1.0 wrapping to 0.\n",
    "\n",
    "spikes = train_dataset.data.detach().cpu().numpy()\n",
    "target_pos = train_dataset.stim[:, 0, :].detach().cpu().numpy()\n",
    "angles = np.arctan2(target_pos[:, 1], target_pos[:, 0])\n",
    "angles = angles / (2 * np.pi) + 0.5\n",
    "angles = (np.round(angles * 8) % 8) / 8\n",
    "\n",
    "cond_angles = np.unique(angles)  # np.unique already returns sorted values\n",
    "cond_shape = (len(cond_angles), spikes.shape[-1])\n",
    "sr_mean_cond = np.empty(cond_shape)\n",
    "sr_std_cond = np.empty(cond_shape)\n",
    "isi_mean_cond = np.empty(cond_shape)\n",
    "isi_std_cond = np.empty(cond_shape)\n",
    "isi_cv_cond = np.empty(cond_shape)\n",
    "for i, angle in enumerate(cond_angles):\n",
    "    mask = (angles == angle)\n",
    "    cond_spikes = spikes[mask]\n",
    "    sr_mean = cond_spikes.mean(axis=(0,1))\n",
    "    sr_std = cond_spikes.std(axis=(0,1))\n",
    "    isi_mean, isi_std, isi_cv = compute_isi_stats(cond_spikes)\n",
    "\n",
    "    sr_mean_cond[i], sr_std_cond[i] = sr_mean, sr_std\n",
    "    isi_mean_cond[i], isi_std_cond[i], isi_cv_cond[i] = isi_mean, isi_std, isi_cv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same per-condition statistics for the held-out (eval) recorded spikes, using the\n",
    "# identical 8-bin angle discretisation of the first-time-step stimulus.\n",
    "\n",
    "test_spikes = eval_dataset.data.detach().cpu().numpy()\n",
    "test_target_pos = eval_dataset.stim[:, 0, :].detach().cpu().numpy()\n",
    "test_angles = np.arctan2(test_target_pos[:, 1], test_target_pos[:, 0])\n",
    "test_angles = test_angles / (2 * np.pi) + 0.5\n",
    "test_angles = (np.round(test_angles * 8) % 8) / 8\n",
    "\n",
    "test_cond_angles = np.unique(test_angles)  # sorted by np.unique\n",
    "test_cond_shape = (len(test_cond_angles), test_spikes.shape[-1])\n",
    "test_sr_mean_cond = np.empty(test_cond_shape)\n",
    "test_sr_std_cond = np.empty(test_cond_shape)\n",
    "test_isi_mean_cond = np.empty(test_cond_shape)\n",
    "test_isi_std_cond = np.empty(test_cond_shape)\n",
    "test_isi_cv_cond = np.empty(test_cond_shape)\n",
    "for i, angle in enumerate(test_cond_angles):\n",
    "    mask = (test_angles == angle)\n",
    "    test_cond_spikes = test_spikes[mask]\n",
    "    test_sr_mean = test_cond_spikes.mean(axis=(0,1))\n",
    "    test_sr_std = test_cond_spikes.std(axis=(0,1))\n",
    "    test_isi_mean, test_isi_std, test_isi_cv = compute_isi_stats(test_cond_spikes)\n",
    "\n",
    "    test_sr_mean_cond[i], test_sr_std_cond[i] = test_sr_mean, test_sr_std\n",
    "    test_isi_mean_cond[i], test_isi_std_cond[i], test_isi_cv_cond[i] = test_isi_mean, test_isi_std, test_isi_cv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Condition-by-condition correlation of the recorded training statistics (each row\n",
    "# of a *_cond array is one condition across neurons). ISI columns that are nan in any\n",
    "# condition are dropped first, because np.corrcoef would propagate the nan.\n",
    "# NOTE: vmin=0 clips negative correlations to the lowest color.\n",
    "fig, axs = plt.subplots(1, 5, figsize=(7.5,1.5))\n",
    "\n",
    "axs[0].imshow(np.corrcoef(sr_mean_cond), vmin=0, vmax=1)\n",
    "axs[0].set_title(\"mean spike rate\")\n",
    "\n",
    "axs[1].imshow(np.corrcoef(sr_std_cond), vmin=0, vmax=1)\n",
    "axs[1].set_title(\"std spike rate\")\n",
    "\n",
    "mask = np.any(np.isnan(isi_mean_cond), axis=0)\n",
    "axs[2].imshow(np.corrcoef(isi_mean_cond[:, ~mask]), vmin=0, vmax=1)\n",
    "axs[2].set_title(\"mean isi\")\n",
    "\n",
    "mask = np.any(np.isnan(isi_std_cond), axis=0)\n",
    "axs[3].imshow(np.corrcoef(isi_std_cond[:, ~mask]), vmin=0, vmax=1)\n",
    "axs[3].set_title(\"std isi\")\n",
    "\n",
    "mask = np.any(np.isnan(isi_cv_cond), axis=0)\n",
    "axs[4].imshow(np.corrcoef(isi_cv_cond[:, ~mask]), vmin=0, vmax=1)\n",
    "axs[4].set_title(\"isi cv\")\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-condition statistics of the model-sampled spikes. Row ri of `inferred_spikes`\n",
    "# was simulated from training trial ri's stimulus, so the training `angles` label it.\n",
    "cond_angles = np.unique(angles)  # sorted, same order as the recorded-spike tables\n",
    "cond_shape = (len(cond_angles), spikes.shape[-1])\n",
    "inf_sr_mean_cond = np.empty(cond_shape)\n",
    "inf_sr_std_cond = np.empty(cond_shape)\n",
    "inf_isi_mean_cond = np.empty(cond_shape)\n",
    "inf_isi_std_cond = np.empty(cond_shape)\n",
    "inf_isi_cv_cond = np.empty(cond_shape)\n",
    "for i, angle in enumerate(cond_angles):\n",
    "    mask = (angles == angle)\n",
    "    cond_spikes = inferred_spikes[mask]\n",
    "    inf_sr_mean = cond_spikes.mean(axis=(0,1))\n",
    "    inf_sr_std = cond_spikes.std(axis=(0,1))\n",
    "    inf_isi_mean, inf_isi_std, inf_isi_cv = compute_isi_stats(cond_spikes)\n",
    "\n",
    "    inf_sr_mean_cond[i], inf_sr_std_cond[i] = inf_sr_mean, inf_sr_std\n",
    "    inf_isi_mean_cond[i], inf_isi_std_cond[i], inf_isi_cv_cond[i] = inf_isi_mean, inf_isi_std, inf_isi_cv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same condition-by-condition correlation, for the model-sampled spike statistics;\n",
    "# ISI columns with nan in any condition are dropped before np.corrcoef.\n",
    "# NOTE: vmin=0 clips negative correlations to the lowest color.\n",
    "fig, axs = plt.subplots(1, 5, figsize=(7.5,1.5))\n",
    "\n",
    "axs[0].imshow(np.corrcoef(inf_sr_mean_cond), vmin=0, vmax=1)\n",
    "axs[0].set_title(\"mean spike rate\")\n",
    "\n",
    "axs[1].imshow(np.corrcoef(inf_sr_std_cond), vmin=0, vmax=1)\n",
    "axs[1].set_title(\"std spike rate\")\n",
    "\n",
    "mask = np.any(np.isnan(inf_isi_mean_cond), axis=0)\n",
    "axs[2].imshow(np.corrcoef(inf_isi_mean_cond[:, ~mask]), vmin=0, vmax=1)\n",
    "axs[2].set_title(\"mean isi\")\n",
    "\n",
    "mask = np.any(np.isnan(inf_isi_std_cond), axis=0)\n",
    "axs[3].imshow(np.corrcoef(inf_isi_std_cond[:, ~mask]), vmin=0, vmax=1)\n",
    "axs[3].set_title(\"std isi\")\n",
    "\n",
    "mask = np.any(np.isnan(inf_isi_cv_cond), axis=0)\n",
    "axs[4].imshow(np.corrcoef(inf_isi_cv_cond[:, ~mask]), vmin=0, vmax=1)\n",
    "axs[4].set_title(\"isi cv\")\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rmse(pred, target):\n",
    "    \"\"\"Root-mean-square error between two equally shaped arrays.\"\"\"\n",
    "    return np.sqrt(np.mean(np.square(pred - target)))\n",
    "\n",
    "def nanrmse(pred, target, warn_extra_nan=False):\n",
    "    \"\"\"RMSE restricted to entries where neither `pred` nor `target` is nan.\n",
    "\n",
    "    Args:\n",
    "        pred, target: equally shaped arrays; nan marks a missing value.\n",
    "        warn_extra_nan: if True, print a notice when `pred` is nan at an entry\n",
    "            where `target` is valid (such entries are still excluded).\n",
    "\n",
    "    Returns:\n",
    "        The RMSE over the jointly valid entries, or nan when there are none.\n",
    "    \"\"\"\n",
    "    pred = np.asarray(pred, dtype=float)\n",
    "    target = np.asarray(target, dtype=float)\n",
    "    target_nan = np.isnan(target)\n",
    "    pred_nan = np.isnan(pred)\n",
    "    if warn_extra_nan and np.any(pred_nan & ~target_nan):\n",
    "        print(\"nans found in predictions for valid targets\")\n",
    "    valid = ~(target_nan | pred_nan)\n",
    "    if not np.any(valid):\n",
    "        return np.float64(np.nan)  # nothing to compare; avoids numpy's mean-of-empty warning\n",
    "    return np.sqrt(np.mean(np.square(pred[valid] - target[valid])))\n",
    "\n",
    "sr_mean_all_rmse = rmse(inf_sr_mean_all, sr_mean_all)\n",
    "sr_std_all_rmse = rmse(inf_sr_std_all, sr_std_all)\n",
    "isi_mean_all_rmse = nanrmse(inf_isi_mean_all, isi_mean_all)\n",
    "isi_std_all_rmse = nanrmse(inf_isi_std_all, isi_std_all)\n",
    "isi_cv_all_rmse = nanrmse(inf_isi_cv_all, isi_cv_all)\n",
    "\n",
    "print(f\"{sr_mean_all_rmse=}\\n{sr_std_all_rmse=}\\n{isi_mean_all_rmse=}\\n{isi_std_all_rmse=}\\n{isi_cv_all_rmse=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean over conditions of the per-condition RMSE between model-sampled and recorded\n",
    "# training statistics.\n",
    "n_conds = sr_mean_cond.shape[0]\n",
    "conds = range(n_conds)\n",
    "\n",
    "sr_mean_cond_rmse = np.mean([rmse(inf_sr_mean_cond[c], sr_mean_cond[c]) for c in conds])\n",
    "sr_std_cond_rmse = np.mean([rmse(inf_sr_std_cond[c], sr_std_cond[c]) for c in conds])\n",
    "isi_mean_cond_rmse = np.mean([nanrmse(inf_isi_mean_cond[c], isi_mean_cond[c]) for c in conds])\n",
    "isi_std_cond_rmse = np.mean([nanrmse(inf_isi_std_cond[c], isi_std_cond[c]) for c in conds])\n",
    "isi_cv_cond_rmse = np.mean([nanrmse(inf_isi_cv_cond[c], isi_cv_cond[c]) for c in conds])\n",
    "\n",
    "print(f\"{sr_mean_cond_rmse=}\\n{sr_std_cond_rmse=}\\n{isi_mean_cond_rmse=}\\n{isi_std_cond_rmse=}\\n{isi_cv_cond_rmse=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: the same per-condition RMSE when every condition is 'predicted' by the\n",
    "# condition-agnostic training statistics (sr_mean_all, ...).\n",
    "n_conds = sr_mean_cond.shape[0]\n",
    "conds = range(n_conds)\n",
    "\n",
    "bl_sr_mean_cond_rmse = np.mean([rmse(sr_mean_all, sr_mean_cond[c]) for c in conds])\n",
    "bl_sr_std_cond_rmse = np.mean([rmse(sr_std_all, sr_std_cond[c]) for c in conds])\n",
    "bl_isi_mean_cond_rmse = np.mean([nanrmse(isi_mean_all, isi_mean_cond[c]) for c in conds])\n",
    "bl_isi_std_cond_rmse = np.mean([nanrmse(isi_std_all, isi_std_cond[c]) for c in conds])\n",
    "bl_isi_cv_cond_rmse = np.mean([nanrmse(isi_cv_all, isi_cv_cond[c]) for c in conds])\n",
    "\n",
    "print(f\"{bl_sr_mean_cond_rmse=}\\n{bl_sr_std_cond_rmse=}\\n{bl_isi_mean_cond_rmse=}\\n{bl_isi_std_cond_rmse=}\\n{bl_isi_cv_cond_rmse=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One condition's recorded per-neuron statistics (x) against the model-sampled spikes\n",
    "# for that condition (blue) and against a different recorded condition (red).\n",
    "eval_cond = 0\n",
    "comp_cond = 3\n",
    "true_vals = [sr_mean_cond[eval_cond], sr_std_cond[eval_cond], isi_mean_cond[eval_cond], isi_std_cond[eval_cond]]\n",
    "inf_vals = [inf_sr_mean_cond[eval_cond], inf_sr_std_cond[eval_cond], inf_isi_mean_cond[eval_cond], inf_isi_std_cond[eval_cond]]\n",
    "inf2_vals = [sr_mean_cond[comp_cond], sr_std_cond[comp_cond], isi_mean_cond[comp_cond], isi_std_cond[comp_cond]]\n",
    "labels = [\"mean spike rate\", \"std spike rate\", \"mean isi\", \"std isi\"]\n",
    "\n",
    "fig, axs = plt.subplots(1, 4, figsize=(4.8,1.4))\n",
    "for i, (true, inf, inf2, label) in enumerate(zip(true_vals, inf_vals, inf2_vals, labels)):\n",
    "    axs[i].scatter(true, inf2, color='darkred', alpha=0.4)\n",
    "    axs[i].scatter(true, inf, color='mediumblue', alpha=0.4)\n",
    "    xlim = axs[i].get_xlim()\n",
    "    ylim = axs[i].get_ylim()\n",
    "    lim = max(xlim[1], ylim[1])\n",
    "    # Identity line over the full square range rather than a fixed [-0.5, 1.5].\n",
    "    axs[i].plot([0, lim], [0, lim], linestyle='--', color='black')\n",
    "    axs[i].set_xlim(0, lim)\n",
    "    axs[i].set_ylim(0, lim)\n",
    "    axs[i].set_title(label)\n",
    "    axs[i].set_xlabel(f\"condition {eval_cond}\")\n",
    "axs[0].set_ylabel(f\"model / condition {comp_cond}\")\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(\"../../plots/cond_spike_isi_stats.svg\")\n",
    "plt.savefig(\"../../plots/cond_spike_isi_stats.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spike-rate-only version of the per-condition comparison: recorded condition (x)\n",
    "# against the model for that condition (blue) and another recorded condition (red).\n",
    "eval_cond = 0\n",
    "comp_cond = 3\n",
    "true_vals = [sr_mean_cond[eval_cond], sr_std_cond[eval_cond]]\n",
    "inf_vals = [inf_sr_mean_cond[eval_cond], inf_sr_std_cond[eval_cond]]\n",
    "inf2_vals = [sr_mean_cond[comp_cond], sr_std_cond[comp_cond]]\n",
    "labels = [\"mean spike rate\", \"std spike rate\"]\n",
    "\n",
    "fig, axs = plt.subplots(1, 2, figsize=(2.4,1.4))\n",
    "for i, (true, inf, inf2, label) in enumerate(zip(true_vals, inf_vals, inf2_vals, labels)):\n",
    "    axs[i].scatter(true, inf2, color='darkred', alpha=0.4)\n",
    "    axs[i].scatter(true, inf, color='mediumblue', alpha=0.4)\n",
    "    xlim = axs[i].get_xlim()\n",
    "    ylim = axs[i].get_ylim()\n",
    "    lim = max(xlim[1], ylim[1])\n",
    "    # Identity line over the full square range rather than a fixed [-0.5, 1.5].\n",
    "    axs[i].plot([0, lim], [0, lim], linestyle='--', color='black')\n",
    "    axs[i].set_xlim(0, lim)\n",
    "    axs[i].set_ylim(0, lim)\n",
    "    axs[i].set_title(label)\n",
    "    axs[i].set_xlabel(f\"condition {eval_cond}\")\n",
    "axs[0].set_ylabel(f\"model / condition {comp_cond}\")\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(\"../../plots/cond_spike_stats.svg\")\n",
    "plt.savefig(\"../../plots/cond_spike_stats.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stand-alone legend for the per-condition scatter plots; the red entry's label\n",
    "# follows `comp_cond` so it stays in sync with those plots.\n",
    "from matplotlib.lines import Line2D\n",
    "from matplotlib.patches import Patch\n",
    "from matplotlib.markers import MarkerStyle\n",
    "\n",
    "fig = plt.figure(figsize=(1.0, 1.0))\n",
    "legend_elements = [Line2D([], [], color='mediumblue', marker=\".\", linestyle='None',\n",
    "                          markersize=7, label='model'),\n",
    "                   Line2D([], [], color='darkred', marker=\".\", linestyle='None',\n",
    "                          markersize=7, label=f'condition {comp_cond}')]\n",
    "plt.legend(handles=legend_elements)\n",
    "plt.axis('off')\n",
    "\n",
    "plt.savefig(\"../../plots/cond_spike_stats_legend.svg\")\n",
    "plt.savefig(\"../../plots/cond_spike_stats_legend.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: correlate the per-condition mean rates with random rows of matching\n",
    "# width (an uncorrelated reference); the neuron count comes from the data, not a literal.\n",
    "n_random_rows = 5\n",
    "plt.imshow(np.corrcoef(sr_mean_cond, np.random.randn(n_random_rows, sr_mean_cond.shape[-1])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-split correlation of per-condition mean rates: rows are eval (test) conditions,\n",
    "# columns are training conditions (left) or model conditions (right). np.corrcoef stacks\n",
    "# both inputs, so [n_cond:, :n_cond] keeps only the cross block.\n",
    "fig, axs = plt.subplots(1, 2, figsize=(2.4, 1.4))\n",
    "n_cond = sr_mean_cond.shape[0]\n",
    "\n",
    "axs[0].imshow(np.corrcoef(sr_mean_cond, test_sr_mean_cond)[n_cond:, :n_cond], vmin=0.0, vmax=1, cmap=\"coolwarm\")\n",
    "axs[0].set_title(\"true\")\n",
    "axs[0].set_xticks([])\n",
    "axs[0].set_yticks([])\n",
    "axs[0].set_xlabel(\"condition\")\n",
    "axs[0].set_ylabel(\"condition\")\n",
    "\n",
    "axs[1].imshow(np.corrcoef(inf_sr_mean_cond, test_sr_mean_cond)[n_cond:, :n_cond], vmin=0.0, vmax=1, cmap=\"coolwarm\")\n",
    "axs[1].set_title(\"model\")\n",
    "axs[1].set_xticks([])\n",
    "axs[1].set_yticks([])\n",
    "axs[1].set_xlabel(\"condition\")\n",
    "# the leading spaces appear to be a manual horizontal offset for the title\n",
    "plt.suptitle(\"     Spike rate correlation\", y=0.95)\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(\"../../plots/cond_spike_corr.svg\")\n",
    "plt.savefig(\"../../plots/cond_spike_corr.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-split correlation of the per-condition ISI CV (rows: test conditions). Neurons\n",
    "# whose CV is nan in any condition of any of the three tables are dropped, so every\n",
    "# matrix is computed over the same set of neurons.\n",
    "fig, axs = plt.subplots(1, 2, figsize=(2.4, 1.4))\n",
    "n_cond = sr_mean_cond.shape[0]\n",
    "\n",
    "test_mask = np.any(np.isnan(test_isi_cv_cond), axis=0)\n",
    "mask = np.any(np.isnan(isi_cv_cond), axis=0)\n",
    "inf_mask = np.any(np.isnan(inf_isi_cv_cond), axis=0)\n",
    "\n",
    "# number of nan neurons per table, then in their union\n",
    "print(test_mask.sum(), mask.sum(), inf_mask.sum())\n",
    "all_mask = np.logical_or(np.logical_or(test_mask, mask), inf_mask)\n",
    "print(all_mask.sum())\n",
    "\n",
    "axs[0].imshow(\n",
    "    np.corrcoef(isi_cv_cond[:, ~all_mask], test_isi_cv_cond[:, ~all_mask])[n_cond:, :n_cond], \n",
    "    vmin=0.0, vmax=1.0, cmap=\"coolwarm\",\n",
    ")\n",
    "# axs[0].set_title(\"true\")\n",
    "axs[0].set_xticks([])\n",
    "axs[0].set_yticks([])\n",
    "axs[0].set_xlabel(\"train conditions\")\n",
    "axs[0].set_ylabel(\"test conditions\")\n",
    "\n",
    "axs[1].imshow(\n",
    "    np.corrcoef(inf_isi_cv_cond[:, ~all_mask], test_isi_cv_cond[:, ~all_mask])[n_cond:, :n_cond],\n",
    "    vmin=0.0, vmax=1.0, cmap=\"coolwarm\",\n",
    ")\n",
    "# axs[1].set_title(\"model\")\n",
    "axs[1].set_xticks([])\n",
    "axs[1].set_yticks([])\n",
    "axs[1].set_xlabel(\"model condition\")\n",
    "plt.suptitle(\"     isi cv correlation\", y=0.95)\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(\"../../plots/cond_isicv_corr.svg\")\n",
    "plt.savefig(\"../../plots/cond_isicv_corr.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stand-alone vertical colorbar (coolwarm, 0 to 1) matching the vmin/vmax of the\n",
    "# correlation images; the figure itself is a thin strip.\n",
    "fig, axs = plt.subplots(figsize=(0.25 / 3, 3 / 3))\n",
    "cmap = mpl.cm.coolwarm\n",
    "norm = mpl.colors.Normalize(vmin=0, vmax=1)\n",
    "\n",
    "cb1 = mpl.colorbar.ColorbarBase(axs, cmap=cmap,\n",
    "                                norm=norm,\n",
    "                                orientation='vertical')\n",
    "\n",
    "plt.savefig(\"../../plots/corr_colorbar.svg\")\n",
    "plt.savefig(\"../../plots/corr_colorbar.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## interpolation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulate the prior for each of the 8 targets: every run starts from the mean\n",
    "# latent state at `offset_start` and receives the target position as a constant\n",
    "# input on [t_on, t_of). Repeated n_repeats times to sample the simulation noise.\n",
    "noise_scale = 1\n",
    "targets = [\n",
    "    [-1,0],\n",
    "    [-1,-1],\n",
    "    [0,-1],\n",
    "    [1,-1],\n",
    "    [1,0],\n",
    "    [1,1],\n",
    "    [0,1],\n",
    "    [-1,1],\n",
    "]\n",
    "dur = 200\n",
    "t_on = 0\n",
    "t_of = 200\n",
    "R_z=0.05\n",
    "prop_cycle =[plt.cm.hsv(i) for i in np.arange(0, 1, 1/len(targets))]\n",
    "initial_state = Qzs_filt[:, offset_start, :].mean(axis=0)[None, :, None, None]\n",
    "z0 = torch.tile(torch.tensor(initial_state, dtype=torch.float), (len(targets), 1, 1, 1))\n",
    "\n",
    "n_repeats = 50\n",
    "# The input is identical for every repeat, so build it once outside the loop.\n",
    "input = np.zeros((len(targets), dur, n_inp))\n",
    "for i, target in enumerate(targets):\n",
    "    input[i, t_on:t_of, :] = np.array(target)[None, :]\n",
    "input = torch.tensor(input, dtype=torch.float).permute(0, 2, 1).cuda()\n",
    "# z0 is built on the CPU; give the prior a copy on the input's device.\n",
    "z0_dev = z0.to(input.device)\n",
    "\n",
    "z_long_prep_all = np.zeros((n_repeats, len(targets), dur, dim_z))\n",
    "with torch.no_grad():\n",
    "    for ri in range(n_repeats):\n",
    "        z = vae.prior.get_latent_time_series(dur, u=input, z0=z0_dev, noise_scale=noise_scale).cpu().numpy()\n",
    "        z_long_prep_all[ri] = z.transpose(0, 2, 1, 3).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trial-averaged latents in raw latent coordinates: time courses on the diagonal,\n",
    "# 2-D phase portraits above it, one color per target.\n",
    "z_long_prep_mean = z_long_prep_all.mean(axis=0)\n",
    "\n",
    "plt_start = 0\n",
    "plt_end=200\n",
    "fig,ax = plt.subplots(dim_z,dim_z)\n",
    "for i in range(dim_z):\n",
    "    for j in range(dim_z):\n",
    "        if j < i:\n",
    "            ax[i,j].axis('off')  # lower triangle is never drawn on\n",
    "        else:\n",
    "            ax[i,j].set_prop_cycle('color',prop_cycle)\n",
    "for targ_i in range(len(targets)):\n",
    "    for i in range(dim_z):\n",
    "        for j in range(i,dim_z):\n",
    "            if i == j:\n",
    "                ax[i,j].plot(z_long_prep_mean[targ_i,plt_start:plt_end,i],alpha = .5)\n",
    "            else:\n",
    "                ax[i,j].plot(z_long_prep_mean[targ_i,plt_start:plt_end,i],z_long_prep_mean[targ_i,plt_start:plt_end,j],alpha = .5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same pairwise view after projecting the trial-averaged latents with\n",
    "# `projection_matrix`: time courses on the diagonal, phase portraits above it.\n",
    "z_long_prep_orth_mean = z_long_prep_mean @ projection_matrix\n",
    "\n",
    "plt_start = 0\n",
    "plt_end=200\n",
    "fig,ax = plt.subplots(dim_z,dim_z)\n",
    "for i in range(dim_z):\n",
    "    for j in range(dim_z):\n",
    "        if j < i:\n",
    "            ax[i,j].axis('off')  # lower triangle is never drawn on\n",
    "        else:\n",
    "            ax[i,j].set_prop_cycle('color',prop_cycle)\n",
    "for targ_i in range(len(targets)):\n",
    "    for i in range(dim_z):\n",
    "        for j in range(i,dim_z):\n",
    "            if i == j:\n",
    "                ax[i,j].plot(z_long_prep_orth_mean[targ_i,plt_start:plt_end,i],alpha = .5)\n",
    "            else:\n",
    "                ax[i,j].plot(z_long_prep_orth_mean[targ_i,plt_start:plt_end,i],z_long_prep_orth_mean[targ_i,plt_start:plt_end,j],alpha = .5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preparatory dynamics in two projected latent dimensions: the trial-averaged\n",
    "# trajectory per target (bold), each repeat (faint), and the mean end points (dots).\n",
    "z_long_prep_orth_all = z_long_prep_all @ projection_matrix\n",
    "initial_state_proj = initial_state.squeeze() @ projection_matrix\n",
    "plt_start = 0\n",
    "plt_end = 100\n",
    "plt_dim1 = 2\n",
    "plt_dim2 = 1\n",
    "fig, axs = plt.subplots(1, 1, figsize=(1.0, 1.0))\n",
    "for targ_i in range(len(targets)):\n",
    "    axs.plot(\n",
    "        z_long_prep_orth_mean[targ_i,plt_start:plt_end,plt_dim1],\n",
    "        z_long_prep_orth_mean[targ_i,plt_start:plt_end,plt_dim2],\n",
    "        color=prop_cycle[targ_i], alpha=0.6)\n",
    "\n",
    "    # Loop over this simulation's own repeats; `z_orth_all` (defined elsewhere in the\n",
    "    # notebook) may hold a different number of trials.\n",
    "    for trial in range(z_long_prep_orth_all.shape[0]):\n",
    "        axs.plot(\n",
    "            z_long_prep_orth_all[trial,targ_i,plt_start:plt_end,plt_dim1],\n",
    "            z_long_prep_orth_all[trial,targ_i,plt_start:plt_end,plt_dim2],\n",
    "            color=prop_cycle[targ_i], alpha=0.1, linewidth=0.2)\n",
    "axs.scatter(z_long_prep_orth_mean[:,plt_end-1,plt_dim1],z_long_prep_orth_mean[:,plt_end-1,plt_dim2], \n",
    "            s=20, color=prop_cycle, edgecolor=\"black\", linewidth=0.8, zorder=2)\n",
    "axs.set_xticks([])\n",
    "axs.set_yticks([])\n",
    "axs.set_xlabel(f\"$z_{plt_dim1}$\")\n",
    "axs.set_ylabel(f\"$z_{plt_dim2}$\")\n",
    "\n",
    "# Pad the x range by 4% on each side so edge trajectories are not clipped.\n",
    "xlim = axs.get_xlim()\n",
    "xrange = xlim[1] - xlim[0]\n",
    "axs.set_xlim(xlim[0] - 0.04 * xrange, xlim[1] + 0.04 * xrange)\n",
    "\n",
    "ylim = axs.get_ylim()\n",
    "yrange = ylim[1] - ylim[0]\n",
    "# axs.set_ylim(ylim[0] - 0.04 * yrange, ylim[1] + 0.04 * yrange)\n",
    "\n",
    "axs.set_title(\"Preparatory dynamics\")\n",
    "plt.savefig(\"../../plots/prep_latents_interp.svg\")\n",
    "plt.savefig(\"../../plots/prep_latents_interp.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same preparatory 2-D plot in the unprojected latent coordinates: thick lines are\n",
    "# per-target means, thin lines individual repeats, outlined dots the last plotted step.\n",
    "plt_start = 0\n",
    "plt_end = 100\n",
    "plt_dim1 = 2\n",
    "plt_dim2 = 0\n",
    "n_trials = z_long_prep_all.shape[0]  # repeats stored in this array (axis 0)\n",
    "fig, axs = plt.subplots(1, 1, figsize=(1.0, 1.0))\n",
    "for targ_i in range(len(targets)):\n",
    "    color = prop_cycle[targ_i]\n",
    "    axs.plot(\n",
    "        z_long_prep_mean[targ_i, plt_start:plt_end, plt_dim1],\n",
    "        z_long_prep_mean[targ_i, plt_start:plt_end, plt_dim2],\n",
    "        color=color, alpha=0.6)\n",
    "    # iterate over this array's own repeat axis, not an unrelated earlier array\n",
    "    for trial in range(n_trials):\n",
    "        axs.plot(\n",
    "            z_long_prep_all[trial, targ_i, plt_start:plt_end, plt_dim1],\n",
    "            z_long_prep_all[trial, targ_i, plt_start:plt_end, plt_dim2],\n",
    "            color=color, alpha=0.1, linewidth=0.2)\n",
    "axs.scatter(\n",
    "    z_long_prep_mean[:, plt_end-1, plt_dim1], z_long_prep_mean[:, plt_end-1, plt_dim2],\n",
    "    s=20, color=prop_cycle, edgecolor=\"black\", linewidth=0.8, zorder=2)\n",
    "axs.set_xticks([])\n",
    "axs.set_yticks([])\n",
    "axs.set_xlabel(f\"$z_{plt_dim1}$\")\n",
    "axs.set_ylabel(f\"$z_{plt_dim2}$\")\n",
    "\n",
    "# pad the x-range by 4% on each side\n",
    "xlim = axs.get_xlim()\n",
    "xrange = xlim[1] - xlim[0]\n",
    "axs.set_xlim(xlim[0] - 0.04 * xrange, xlim[1] + 0.04 * xrange)\n",
    "\n",
    "axs.set_title(\"Preparatory dynamics\")\n",
    "# distinct file name: the projected version of this plot writes prep_latents_interp.*\n",
    "plt.savefig(\"../../plots/prep_latents_interp_raw.svg\")\n",
    "plt.savefig(\"../../plots/prep_latents_interp_raw.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulate reaches from the learned prior: every target condition starts from the mean of\n",
    "# Qzs_filt at time index offset_start (averaged over axis 0), receives its target as a\n",
    "# constant input on steps [t_on, t_of), and is rolled out n_repeats times with noise_scale.\n",
    "noise_scale = 1\n",
    "targets = [\n",
    "    [-1,0],\n",
    "    [-1,-1],\n",
    "    [0,-1],\n",
    "    [1,-1],\n",
    "    [1,0],\n",
    "    [1,1],\n",
    "    [0,1],\n",
    "    [-1,1],\n",
    "]\n",
    "dur = 100\n",
    "t_on = 0\n",
    "t_of = 12\n",
    "prop_cycle = [plt.cm.hsv(i) for i in np.arange(0, 1, 1/len(targets))]\n",
    "initial_state = Qzs_filt[:, offset_start, :].mean(axis=0)[None, :, None, None]\n",
    "z0 = torch.tile(torch.tensor(initial_state, dtype=torch.float), (len(targets), 1, 1, 1))\n",
    "\n",
    "# The input pulses are identical for every repeat, so build them once.\n",
    "inputs_np = np.zeros((len(targets), dur, n_inp))\n",
    "for i, target in enumerate(targets):\n",
    "    inputs_np[i, t_on:t_of, :] = np.array(target)[None, :]\n",
    "# permute to (targets, n_inp, dur), the layout passed to the prior\n",
    "inputs = torch.tensor(inputs_np, dtype=torch.float).permute(0, 2, 1).cuda()\n",
    "\n",
    "n_repeats = 50\n",
    "z_long_reach_all = np.zeros((n_repeats, len(targets), dur, dim_z))\n",
    "for ri in range(n_repeats):\n",
    "    z = vae.prior.get_latent_time_series(dur, u=inputs, z0=z0, noise_scale=noise_scale).cpu().numpy()\n",
    "    z_long_reach_all[ri] = z.transpose(0, 2, 1, 3).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# average the simulated reaches over the repeat axis (axis 0)\n",
    "z_long_reach_mean = np.mean(z_long_reach_all, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pairwise view of the mean simulated reach latents (unprojected): diagonal panels show\n",
    "# dimension i over time, upper-triangle panels dimension i against j; one line per\n",
    "# target, coloured by prop_cycle.\n",
    "plt_start = 0\n",
    "plt_end = 100\n",
    "# squeeze=False keeps `ax` 2-D even when dim_z == 1\n",
    "fig, ax = plt.subplots(dim_z, dim_z, squeeze=False)\n",
    "for i in range(dim_z):\n",
    "    for j in range(dim_z):\n",
    "        if j < i:\n",
    "            # the lower triangle is unused; hide its empty axes\n",
    "            ax[i, j].axis(\"off\")\n",
    "            continue\n",
    "        for targ_i in range(len(targets)):\n",
    "            traj = z_long_reach_mean[targ_i, plt_start:plt_end]\n",
    "            if i == j:\n",
    "                ax[i, j].plot(traj[:, i], color=prop_cycle[targ_i], alpha=.5)\n",
    "            else:\n",
    "                ax[i, j].plot(traj[:, i], traj[:, j], color=prop_cycle[targ_i], alpha=.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pairwise view of the mean simulated reach latents after multiplying by projection_matrix:\n",
    "# diagonal panels show dimension i over time, upper-triangle panels dimension i against j;\n",
    "# one line per target, coloured by prop_cycle.\n",
    "z_long_reach_orth_mean = z_long_reach_mean @ projection_matrix\n",
    "\n",
    "plt_start = 0\n",
    "plt_end = 100\n",
    "# squeeze=False keeps `ax` 2-D even when dim_z == 1\n",
    "fig, ax = plt.subplots(dim_z, dim_z, squeeze=False)\n",
    "for i in range(dim_z):\n",
    "    for j in range(dim_z):\n",
    "        if j < i:\n",
    "            # the lower triangle is unused; hide its empty axes\n",
    "            ax[i, j].axis(\"off\")\n",
    "            continue\n",
    "        for targ_i in range(len(targets)):\n",
    "            traj = z_long_reach_orth_mean[targ_i, plt_start:plt_end]\n",
    "            if i == j:\n",
    "                ax[i, j].plot(traj[:, i], color=prop_cycle[targ_i], alpha=.5)\n",
    "            else:\n",
    "                ax[i, j].plot(traj[:, i], traj[:, j], color=prop_cycle[targ_i], alpha=.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3-D view of the simulated reach latents in projected coordinates. plt_start = 12 matches\n",
    "# t_of, the end of the input pulse. Thick lines are per-target means, thin lines individual\n",
    "# repeats; circles mark the first and triangles the last plotted step of each mean.\n",
    "fig, axs = plt.subplots(1, 1, figsize=(1.2, 1.2), subplot_kw={'projection': '3d'})\n",
    "\n",
    "z_long_reach_orth_all = z_long_reach_all @ projection_matrix\n",
    "plt_start = 12\n",
    "plt_end = 100\n",
    "plt_dim1 = 2\n",
    "plt_dim2 = 3\n",
    "plt_dim3 = 4\n",
    "n_trials = z_long_reach_orth_all.shape[0]  # repeats stored in this array (axis 0)\n",
    "\n",
    "for targ_i in range(len(targets)):\n",
    "    color = prop_cycle[targ_i]\n",
    "    axs.plot(\n",
    "        z_long_reach_orth_mean[targ_i, plt_start:plt_end, plt_dim1],\n",
    "        z_long_reach_orth_mean[targ_i, plt_start:plt_end, plt_dim2],\n",
    "        z_long_reach_orth_mean[targ_i, plt_start:plt_end, plt_dim3],\n",
    "        color=color, alpha=0.6)\n",
    "    # iterate over this array's own repeat axis, not an unrelated earlier array\n",
    "    for trial in range(n_trials):\n",
    "        axs.plot(\n",
    "            z_long_reach_orth_all[trial, targ_i, plt_start:plt_end, plt_dim1],\n",
    "            z_long_reach_orth_all[trial, targ_i, plt_start:plt_end, plt_dim2],\n",
    "            z_long_reach_orth_all[trial, targ_i, plt_start:plt_end, plt_dim3],\n",
    "            color=color, alpha=0.1, linewidth=0.2)\n",
    "\n",
    "axs.scatter(z_long_reach_orth_mean[:,plt_start,plt_dim1],z_long_reach_orth_mean[:,plt_start,plt_dim2], z_long_reach_orth_mean[:,plt_start,plt_dim3],\n",
    "            s=15, color=prop_cycle, edgecolor=\"black\", linewidth=0.8, zorder=2, alpha=0.8)\n",
    "axs.scatter(z_long_reach_orth_mean[:,plt_end-1,plt_dim1],z_long_reach_orth_mean[:,plt_end-1,plt_dim2], z_long_reach_orth_mean[:,plt_end-1,plt_dim3],\n",
    "            s=15, marker=\"^\", color=prop_cycle, edgecolor=\"black\", linewidth=0.8, zorder=2, alpha=0.8)\n",
    "axs.set_xticks([])\n",
    "axs.set_yticks([])\n",
    "axs.set_zticks([])\n",
    "axs.set_xlabel(f\"$z_{plt_dim1}$\")\n",
    "axs.set_ylabel(f\"$z_{plt_dim2}$\")\n",
    "axs.set_zlabel(f\"$z_{plt_dim3}$\")\n",
    "# negative label padding pulls the axis labels in toward the small 3-D axes\n",
    "axs.xaxis.labelpad=-15\n",
    "axs.yaxis.labelpad=-15\n",
    "axs.zaxis.labelpad=-16\n",
    "\n",
    "axs.set_title(\"Movement interpolation\", pad=-10)\n",
    "plt.savefig(\"../../plots/movement_latents_interp.svg\")\n",
    "plt.savefig(\"../../plots/movement_latents_interp.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TODO\n",
    "\n",
    "- Spike statistics of conditionally generated data (not of data reconstructed from inferred latents)\n",
    "- Generation of an unseen reach (straight down)\n",
    "- Projection of the latents onto interpretable axes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decoding projections: linear regressions between latents and behavioural features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latents used for the regressions below (an alias of Qzs_filt, not a copy).\n",
    "latents = Qzs_filt\n",
    "\n",
    "\n",
    "def flatten2d(arr):\n",
    "    \"\"\"Reshape `arr` to 2-D: keep the last axis and merge all leading axes into one.\"\"\"\n",
    "    return arr.reshape(-1, arr.shape[-1])\n",
    "\n",
    "\n",
    "latents.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take stim at index 0 along axis 1 and repeat it latents.shape[1] times along that axis,\n",
    "# so the target feature lines up with `latents` in its first two axes.\n",
    "stim_first = train_dataset.stim[:, 0:1, :].detach().cpu().numpy()\n",
    "target_position = np.repeat(stim_first, latents.shape[1], axis=1)\n",
    "target_position.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = DATA_ROOT / \"mc_maze_input\" / \"pos\" / \"eval_target_20ms.h5\"\n",
    "# read the whole train_behavior dataset into memory, then close the file\n",
    "with h5py.File(data_path, \"r\") as h5f:\n",
    "    maze_group = h5f[\"mc_maze_20\"]\n",
    "    beh_velocity = maze_group[\"train_behavior\"][()]\n",
    "beh_velocity.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cumulative sum of velocity over time (axis 1), divided by 50 -- presumably the number of\n",
    "# 20 ms bins per second, turning a per-second velocity into a position; confirm units\n",
    "BINS_PER_SECOND = 50\n",
    "beh_position = np.cumsum(beh_velocity, axis=1) / BINS_PER_SECOND\n",
    "beh_position.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# finite-difference derivative of velocity along axis 1 with unit spacing, i.e. change per\n",
    "# time bin (unlike beh_position, not rescaled by 50)\n",
    "beh_accel = np.gradient(beh_velocity, 1, axis=1)\n",
    "beh_accel.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# time-bin index 0..T-1 (T = latents.shape[1]) as an extra feature, repeated for every\n",
    "# entry along latents' first axis -> shape (latents.shape[0], T, 1)\n",
    "time_index = np.arange(latents.shape[1])[None, :, None]\n",
    "time = np.repeat(time_index, latents.shape[0], axis=0)\n",
    "time.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# Stack the task/behaviour features along the last axis (order: target_position,\n",
    "# beh_position, beh_velocity, beh_accel, time) and z-score each column over all samples.\n",
    "features = np.concatenate(\n",
    "    [target_position, beh_position, beh_velocity, beh_accel, time],\n",
    "    axis=-1,\n",
    ")\n",
    "ss_feat = StandardScaler()\n",
    "ss_features = ss_feat.fit_transform(flatten2d(features)).reshape(*features.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# z-score every latent dimension over all flattened samples\n",
    "ss_lat = StandardScaler()\n",
    "ss_latents = ss_lat.fit_transform(flatten2d(latents)).reshape(*latents.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.linear_model import Ridge\n",
    "\n",
    "# Encoding direction: predict the latents from the standardized task/behaviour features.\n",
    "# Kept under its own name so the latents -> features decoder in the next cell does not\n",
    "# overwrite it; the ridge penalty is picked from {1e-6, 1e-4, 1e-2} by grid search.\n",
    "encoder = GridSearchCV(Ridge(), param_grid=dict(alpha=np.logspace(-6, -2, 3)))\n",
    "encoder.fit(flatten2d(ss_features), flatten2d(latents))\n",
    "encoder.best_score_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import Ridge\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "\n",
    "# Decoding direction: predict the standardized features from the latents, choosing the\n",
    "# ridge penalty from {1e-6, 1e-4, 1e-2} with GridSearchCV's default cross-validation.\n",
    "alpha_grid = {\"alpha\": np.logspace(-6, -2, 3)}\n",
    "decoder = GridSearchCV(Ridge(), param_grid=alpha_grid)\n",
    "decoder.fit(flatten2d(latents), flatten2d(ss_features))\n",
    "decoder.best_score_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# weights of the best ridge decoder (one row per predicted feature)\n",
    "best_ridge = decoder.best_estimator_\n",
    "best_ridge.coef_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# decode every flattened sample, then restore the layout of `features`\n",
    "flat_pred = decoder.predict(flatten2d(latents))\n",
    "pred_features = flat_pred.reshape(features.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# entry 0 of ss_features: standardized feature 0 vs. its decoded prediction\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(ss_features[0, :, 0], label=\"true\")\n",
    "ax.plot(pred_features[0, :, 0], label=\"decoded\")\n",
    "ax.set(title=\"Feature 0, trial 0\", xlabel=\"time bin\", ylabel=\"standardized value\")\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# entry 0 of ss_features: standardized feature 1 vs. its decoded prediction\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(ss_features[0, :, 1], label=\"true\")\n",
    "ax.plot(pred_features[0, :, 1], label=\"decoded\")\n",
    "ax.set(title=\"Feature 1, trial 0\", xlabel=\"time bin\", ylabel=\"standardized value\")\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# entry 0 of ss_features: standardized feature 2 vs. its decoded prediction\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(ss_features[0, :, 2], label=\"true\")\n",
    "ax.plot(pred_features[0, :, 2], label=\"decoded\")\n",
    "ax.set(title=\"Feature 2, trial 0\", xlabel=\"time bin\", ylabel=\"standardized value\")\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# entry 0 of ss_features: standardized feature 3 vs. its decoded prediction\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(ss_features[0, :, 3], label=\"true\")\n",
    "ax.plot(pred_features[0, :, 3], label=\"decoded\")\n",
    "ax.set(title=\"Feature 3, trial 0\", xlabel=\"time bin\", ylabel=\"standardized value\")\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# entry 0 of ss_features: standardized feature 4 vs. its decoded prediction\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(ss_features[0, :, 4], label=\"true\")\n",
    "ax.plot(pred_features[0, :, 4], label=\"decoded\")\n",
    "ax.set(title=\"Feature 4, trial 0\", xlabel=\"time bin\", ylabel=\"standardized value\")\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# entry 0 of ss_features: standardized feature 5 vs. its decoded prediction\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(ss_features[0, :, 5], label=\"true\")\n",
    "ax.plot(pred_features[0, :, 5], label=\"decoded\")\n",
    "ax.set(title=\"Feature 5, trial 0\", xlabel=\"time bin\", ylabel=\"standardized value\")\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# entry 0 of ss_features: standardized feature 6 vs. its decoded prediction\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(ss_features[0, :, 6], label=\"true\")\n",
    "ax.plot(pred_features[0, :, 6], label=\"decoded\")\n",
    "ax.set(title=\"Feature 6, trial 0\", xlabel=\"time bin\", ylabel=\"standardized value\")\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# entry 0 of ss_features: standardized feature 7 vs. its decoded prediction\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(ss_features[0, :, 7], label=\"true\")\n",
    "ax.plot(pred_features[0, :, 7], label=\"decoded\")\n",
    "ax.set(title=\"Feature 7, trial 0\", xlabel=\"time bin\", ylabel=\"standardized value\")\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# entry 0 of ss_features: standardized feature 8 vs. its decoded prediction\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(ss_features[0, :, 8], label=\"true\")\n",
    "ax.plot(pred_features[0, :, 8], label=\"decoded\")\n",
    "ax.set(title=\"Feature 8, trial 0\", xlabel=\"time bin\", ylabel=\"standardized value\")\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exploratory checks (weird stuff): 1-D slices of the learned transition function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1-D slice of vae.prior's transition function: output component 0 as z_0 sweeps\n",
    "# [-10, 10) in steps of 0.1, with the remaining dim_z - 1 coordinates held at zero.\n",
    "z_0_range = np.arange(-10, 10, 0.1)\n",
    "f_z_0 = []\n",
    "z_1_5 = torch.zeros(dim_z - 1, dtype=torch.float)\n",
    "for z_0 in z_0_range:\n",
    "    z = torch.cat([torch.tensor([z_0], dtype=torch.float), z_1_5]).to(\"cuda:0\")\n",
    "    f_z_0.append(vae.prior.transition.forward(z=z[None, :, None, None], grad=False).squeeze()[0].item())\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(z_0_range, f_z_0)\n",
    "ax.set(xlabel=\"$z_0$\", ylabel=\"$f(z)_0$\", title=\"Transition output 0, other dims = 0\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1-D slices of vae_orth's transition function: output component 0 as z_0 sweeps [-10, 10),\n",
    "# for 20 random settings (standard normal * 3) of the other dim_z - 1 coordinates.\n",
    "# NOTE(review): vae_orth is defined outside this section -- confirm it exists.\n",
    "# A local seeded generator makes the figure reproducible without touching torch's global RNG.\n",
    "z_0_range = np.arange(-10, 10, 0.1)\n",
    "slice_rng = torch.Generator().manual_seed(0)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "for _ in range(20):\n",
    "    f_z_0 = []\n",
    "    z_1_5 = torch.randn(dim_z - 1, dtype=torch.float, generator=slice_rng) * 3\n",
    "    for z_0 in z_0_range:\n",
    "        z = torch.cat([torch.tensor([z_0], dtype=torch.float), z_1_5]).to(\"cuda:0\")\n",
    "        f_z_0.append(vae_orth.prior.transition.forward(z=z[None, :, None, None], grad=False).squeeze()[0].item())\n",
    "    ax.plot(z_0_range, f_z_0)\n",
    "ax.set(xlabel=\"$z_0$\", ylabel=\"$f(z)_0$\", title=\"vae_orth transition, output 0\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1-D slices of vae_orth's transition function: output component 1 as z_0 sweeps [-10, 10),\n",
    "# for 20 random settings (standard normal * 3) of the other dim_z - 1 coordinates.\n",
    "# NOTE(review): vae_orth is defined outside this section -- confirm it exists.\n",
    "# A local seeded generator makes the figure reproducible without touching torch's global RNG.\n",
    "z_0_range = np.arange(-10, 10, 0.1)\n",
    "slice_rng = torch.Generator().manual_seed(1)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "for _ in range(20):\n",
    "    f_z_1 = []\n",
    "    z_1_5 = torch.randn(dim_z - 1, dtype=torch.float, generator=slice_rng) * 3\n",
    "    for z_0 in z_0_range:\n",
    "        z = torch.cat([torch.tensor([z_0], dtype=torch.float), z_1_5]).to(\"cuda:0\")\n",
    "        f_z_1.append(vae_orth.prior.transition.forward(z=z[None, :, None, None], grad=False).squeeze()[1].item())\n",
    "    ax.plot(z_0_range, f_z_1)\n",
    "ax.set(xlabel=\"$z_0$\", ylabel=\"$f(z)_1$\", title=\"vae_orth transition, output 1\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "virnn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
