{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analyzing results for Spatial Navigation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import matplotlib as mpl\n",
    "import matplotlib.font_manager as fm\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import pathlib\n",
    "import scipy.stats as stats\n",
    "from sklearn.decomposition import PCA\n",
    "import sys\n",
    "import torch\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "from models import rnn\n",
    "from tasks import spatial_navigation\n",
    "import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if \"Arial\" not in fm.get_font_names():\n",
    "#    font_path = pathlib.Path.home() / \"fonts\" / \"arial.ttf\"  # Set to correct path\n",
    "#    fm.fontManager.addfont(font_path)\n",
    "\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"  # \"Arial\"\n",
    "plt.rcParams[\"font.size\"] = 10\n",
    "plt.rcParams[\"axes.linewidth\"] = 1.2\n",
    "plt.rcParams[\"xtick.major.width\"] = 1.2\n",
    "plt.rcParams[\"ytick.major.width\"] = 1.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "active = \"#2ca7c5\"\n",
    "quiescent = \"#ee3233\"\n",
    "\n",
    "active_colors = [\n",
    "    \"#ffffff\",\n",
    "    \"#d4edf3\",\n",
    "    \"#aadbe7\",\n",
    "    \"#80cadc\",\n",
    "    \"#56b8d0\",\n",
    "    \"#2ca7c5\",\n",
    "    \"#2796b1\",\n",
    "    \"#1e7489\",\n",
    "    \"#165362\",\n",
    "    \"#0d323b\",\n",
    "]\n",
    "lcmap_active = mpl.colors.LinearSegmentedColormap.from_list(\n",
    "    \"lcmap_active\", active_colors\n",
    ")\n",
    "norm = mpl.colors.Normalize(vmin=0, vmax=1.5)\n",
    "lcmap_active = mpl.cm.ScalarMappable(norm=norm, cmap=lcmap_active)\n",
    "lcmap_active.set_array([])\n",
    "\n",
    "quiescent_colors = [\n",
    "    \"#ffffff\",\n",
    "    \"#fbd6d6\",\n",
    "    \"#f8adad\",\n",
    "    \"#f48484\",\n",
    "    \"#f15a5b\",\n",
    "    \"#ee3233\",\n",
    "    \"#d62d2d\",\n",
    "    \"#a62323\",\n",
    "    \"#771919\",\n",
    "    \"#470f0f\",\n",
    "]\n",
    "lcmap_quiescent = mpl.colors.LinearSegmentedColormap.from_list(\n",
    "    \"lcmap_quiescent\", quiescent_colors\n",
    ")\n",
    "norm = mpl.colors.Normalize(vmin=0, vmax=1.5)\n",
    "lcmap_quiescent = mpl.cm.ScalarMappable(norm=norm, cmap=lcmap_quiescent)\n",
    "lcmap_quiescent.set_array([])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "unmask_every = 6\n",
    "epoch = 15_000  # epoch = 5000 #15000 #2500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = pathlib.Path(\"../results\")\n",
    "noise = 0.0001\n",
    "config_name = f\"noisy_unbiased_unmask_every_{unmask_every}_{noise}\"\n",
    "\n",
    "# Run directories are named \"<config_name>_<seed>\". Match that exact prefix and\n",
    "# require an all-digit suffix, so entries whose name merely *contains*\n",
    "# `config_name` (e.g. a different noise level sharing the same leading digits) are\n",
    "# ignored, and sort the seeds because directory listing order is arbitrary.\n",
    "run_prefix = f\"{config_name}_\"\n",
    "seeds = sorted(\n",
    "    int(entry.name[len(run_prefix) :])\n",
    "    for entry in (results / \"spatial_navigation\").iterdir()\n",
    "    if entry.name.startswith(run_prefix) and entry.name[len(run_prefix) :].isdigit()\n",
    ")\n",
    "num_seeds = len(seeds)\n",
    "seed_dirs = [results / \"spatial_navigation\" / f\"{config_name}_{seed}\" for seed in seeds]\n",
    "figures = results / \"spatial_navigation\" / \"figures\"\n",
    "figures.mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# idx = seeds.index(15)\n",
    "# seed_dirs = seed_dirs[idx:idx+1]\n",
    "# seeds = seeds[idx:idx+1]\n",
    "seeds, seed_dirs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualizations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path.read_text() opens and closes each metrics file, so no handles are leaked.\n",
    "test_metrics = [json.loads((d / \"test_metrics.json\").read_text()) for d in seed_dirs]\n",
    "# Test metrics are read at every `skip`-th batch from `start` up to and including `epoch`.\n",
    "start, end, skip = 100, epoch + 1, 100\n",
    "test_losses = np.array(\n",
    "    [[f[str(i)][\"loss\"] for i in range(start, end, skip)] for f in test_metrics]\n",
    ")\n",
    "test_posmse = np.array(\n",
    "    [[f[str(i)][\"pos_mse\"] for i in range(start, end, skip)] for f in test_metrics]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "losses_mean, losses_std = test_losses.mean(axis=0), test_losses.std(axis=0)\n",
    "posmse_mean, posmse_std = test_posmse.mean(axis=0), test_posmse.std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(1.5, 1.5))\n",
    "\n",
    "ax.spines[[\"right\", \"top\"]].set_visible(False)\n",
    "ax.set_box_aspect(1)\n",
    "ax.errorbar(\n",
    "    np.arange(start, end, skip), losses_mean, yerr=losses_std, fmt=\"-\", c=\"gray\"\n",
    ")\n",
    "ax.set_xlabel(\"batches\")\n",
    "ax.set_ylabel(\"test loss\", color=\"gray\")\n",
    "ax.set_yticks(np.linspace(1e-5, 5e-5, 5))\n",
    "ax.tick_params(axis=\"both\", which=\"major\", labelsize=8)\n",
    "ax.tick_params(axis=\"both\", which=\"minor\", labelsize=8)\n",
    "ax.yaxis.get_offset_text().set_fontsize(8)\n",
    "\n",
    "ax2 = ax.twinx()\n",
    "ax2.spines[[\"left\", \"top\"]].set_visible(False)\n",
    "ax2.set_box_aspect(1)\n",
    "ax2.errorbar(\n",
    "    np.arange(start, end, skip), posmse_mean, yerr=posmse_std, fmt=\"-\", c=\"black\"\n",
    ")\n",
    "ax2.set_ylabel(\"position decoding error\", color=\"black\")\n",
    "ax2.tick_params(axis=\"both\", which=\"major\", labelsize=8)\n",
    "ax2.tick_params(axis=\"both\", which=\"minor\", labelsize=8)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path.read_text() opens and closes each metrics file, so no handles are leaked.\n",
    "train_metrics = [\n",
    "    json.loads((d / \"train_metrics.json\").read_text()) for d in seed_dirs\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 0\n",
    "fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 3))\n",
    "axs[0, 0].plot([v[\"loss\"] for (k, v) in train_metrics[i].items()][100:])\n",
    "axs[0, 1].plot([v[\"pos_mse\"] for (k, v) in train_metrics[i].items()][100:])\n",
    "axs[1, 0].plot([v[\"loss\"] for (k, v) in test_metrics[i].items()])\n",
    "axs[1, 1].plot([v[\"pos_mse\"] for (k, v) in test_metrics[i].items()])\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig.savefig(figures / f\"{config_name}-loss.pdf\", bbox_inches=\"tight\", pad_inches=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Population activity dimensionality (PCs vs cumulative explained variance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def momentum(self, lambda_v, tau_a, b_a, x, init_state=None):\n",
    "    \"\"\"Roll the RNN `self` forward over `x` with extra velocity and adaptation states.\n",
    "\n",
    "    At every timestep the usual leaky update `next_h` is computed, and then\n",
    "\n",
    "        v <- (1 - lambda_v) * v + (next_h - h)\n",
    "        h <- h + v - c\n",
    "        c <- (1 / tau_a) * (c + b_a * h)\n",
    "\n",
    "    so `lambda_v=1` removes the velocity carry-over and `b_a=0` keeps `c` at zero.\n",
    "\n",
    "    Args:\n",
    "        self: model providing `encoder`, `w_rec`, `w_in`, `w_out`, `activation`,\n",
    "            `dt`, `tau`, `sigma_in`, `sigma_rec`, `sigma_out`, `n_in`, `n_rec`,\n",
    "            `n_out`, `n_init` and `device`. Its `h`, `v`, `c`, `u`, `y` and `*_1t`\n",
    "            attributes are overwritten.\n",
    "        lambda_v: decay applied to the velocity state `v`.\n",
    "        tau_a: time constant of the adaptation state `c`.\n",
    "        b_a: gain with which `h` feeds the adaptation state `c`.\n",
    "        x: inputs, indexed as `x[:, t, :]` and reshaped to `(batch_size, n_in)`.\n",
    "        init_state: optional initial state, reshaped to `(batch_size, n_init)` and\n",
    "            passed through `self.encoder`; the hidden state starts at zero if omitted.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(h_1t, y_1t)` with shapes `(batch_size, timesteps, n_rec)` and\n",
    "        `(batch_size, timesteps, n_out)`.\n",
    "    \"\"\"\n",
    "    batch_size, timesteps = x.shape[0], x.shape[1]\n",
    "\n",
    "    if init_state is not None:\n",
    "        self.h = self.encoder(init_state.reshape(batch_size, self.n_init))\n",
    "    else:\n",
    "        self.h = torch.zeros(batch_size, self.n_rec, device=self.device)\n",
    "    self.v = torch.zeros_like(self.h)\n",
    "    self.c = torch.zeros_like(self.h)\n",
    "\n",
    "    self.h_1t = torch.zeros(batch_size, timesteps, self.n_rec, device=self.device)\n",
    "    self.v_1t = torch.zeros_like(self.h_1t)\n",
    "    self.c_1t = torch.zeros_like(self.h_1t)\n",
    "    self.y_1t = torch.zeros(batch_size, timesteps, self.n_out, device=self.device)\n",
    "\n",
    "    for t in range(timesteps):\n",
    "        x_t = x[:, t, :].reshape(batch_size, self.n_in)\n",
    "        # Zero-mean Gaussian noise, matching the output noise below. A uniform\n",
    "        # draw on [0, 1) has mean 0.5 and would add a constant positive drift.\n",
    "        # NOTE(review): confirm this matches the noise used in the model's own forward pass.\n",
    "        noise_in = torch.randn_like(x_t, device=self.device)\n",
    "        self.u = self.w_rec(self.h) + self.w_in(x_t + self.sigma_in * noise_in)\n",
    "\n",
    "        noise_rec = torch.randn_like(self.h, device=self.device)\n",
    "        next_h = (\n",
    "            (1 - self.dt / self.tau) * self.h\n",
    "            + (self.dt / self.tau) * self.activation(self.u)\n",
    "            + self.sigma_rec * noise_rec\n",
    "        )\n",
    "        delta_h = next_h - self.h\n",
    "        self.v = (1 - lambda_v) * self.v + delta_h\n",
    "        # The adaptation state is subtracted first and only then updated from the new h.\n",
    "        self.h += self.v - self.c\n",
    "        self.c = (1 / tau_a) * (self.c + b_a * self.h)\n",
    "\n",
    "        noise_out = torch.randn(batch_size, self.n_out, device=self.device)\n",
    "        self.y = self.w_out(self.h) + self.sigma_out * noise_out\n",
    "\n",
    "        self.h_1t[:, t, :] = self.h\n",
    "        self.v_1t[:, t, :] = self.v\n",
    "        self.c_1t[:, t, :] = self.c\n",
    "        self.y_1t[:, t, :] = self.y\n",
    "\n",
    "    return self.h_1t, self.y_1t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ev_active = []\n",
    "ev_quiescent = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for s_i, seed in enumerate(seeds):\n",
    "    utils.set_random_seeds(seed)\n",
    "\n",
    "    model = torch.load(seed_dirs[s_i] / f\"model_{epoch}.pt\", map_location=\"cpu\")\n",
    "    model.set_device(\"cpu\")\n",
    "    model.eval()\n",
    "\n",
    "    task = model.task\n",
    "    task.batch_size = 1000\n",
    "\n",
    "    test_data = task.get_test_batch()\n",
    "    t_quiescent = 200\n",
    "    quiescent_inputs = torch.zeros(task.batch_size, t_quiescent, 2)\n",
    "\n",
    "    # Active condition: the trained model driven by the task inputs.\n",
    "    h_active, _ = model(test_data[\"data\"], test_data[\"init_state\"])\n",
    "    h_active = h_active.cpu().detach().numpy().reshape(-1, model.n_rec)\n",
    "\n",
    "    # Quiescent condition: recurrent noise scaled by sqrt(2) and the all-zero\n",
    "    # `quiescent_inputs` (previously the task inputs were passed here, leaving\n",
    "    # `quiescent_inputs` unused). With lambda_v=1 and b_a=0 the velocity and\n",
    "    # adaptation terms of `momentum` drop out.\n",
    "    model.sigma_rec *= np.sqrt(2)\n",
    "    h_quiescent, _ = momentum(\n",
    "        model,\n",
    "        x=quiescent_inputs,\n",
    "        init_state=test_data[\"init_state\"],\n",
    "        lambda_v=1,\n",
    "        tau_a=100,\n",
    "        b_a=0,\n",
    "    )\n",
    "    h_quiescent = h_quiescent.cpu().detach().numpy().reshape(-1, model.n_rec)\n",
    "\n",
    "    # Fit a separate PCA per condition and keep the explained-variance spectra.\n",
    "    pca = PCA()\n",
    "    pca.fit(h_active)\n",
    "    ev_active.append(pca.explained_variance_ratio_)\n",
    "    pca.fit(h_quiescent)\n",
    "    ev_quiescent.append(pca.explained_variance_ratio_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ev_active = np.array(ev_active)\n",
    "ev_quiescent = np.array(ev_quiescent)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(1.5, 1.5))\n",
    "ax.spines[[\"right\", \"top\"]].set_visible(False)\n",
    "ax.set_box_aspect(1)\n",
    "\n",
    "# Mean +/- std across seeds of the cumulative explained variance, first 10 PCs.\n",
    "ax.errorbar(\n",
    "    np.arange(1, 11),\n",
    "    ev_active.cumsum(axis=-1).mean(axis=0)[:10],\n",
    "    yerr=ev_active.cumsum(axis=-1).std(axis=0)[:10],\n",
    "    fmt=\"-\",\n",
    "    c=active,\n",
    "    label=\"active\",\n",
    ")\n",
    "ax.errorbar(\n",
    "    np.arange(1, 11),\n",
    "    ev_quiescent.cumsum(axis=-1).mean(axis=0)[:10],\n",
    "    yerr=ev_quiescent.cumsum(axis=-1).std(axis=0)[:10],\n",
    "    fmt=\"-\",\n",
    "    c=quiescent,\n",
    "    label=\"quiescent\",\n",
    ")\n",
    "\n",
    "ax.set_xlabel(\"# of PCs\")\n",
    "ax.set_ylabel(\"explained variance\")\n",
    "ax.set_xticks(np.arange(1, 11))\n",
    "ax.set_yticks(np.arange(0.4, 1.0, 0.1))\n",
    "ax.tick_params(axis=\"both\", which=\"major\", labelsize=8)\n",
    "ax.tick_params(axis=\"both\", which=\"minor\", labelsize=8)\n",
    "\n",
    "leg = ax.legend(loc=\"best\", frameon=False)\n",
    "\n",
    "# `Legend.legendHandles` was renamed to `legend_handles` in Matplotlib 3.7 and the\n",
    "# old name was later removed, so use whichever attribute this version provides.\n",
    "legend_handles = (\n",
    "    leg.legend_handles if hasattr(leg, \"legend_handles\") else leg.legendHandles\n",
    ")\n",
    "# Colour each legend label like its curve and hide the handle itself.\n",
    "for handle, text in zip(legend_handles, leg.get_texts()):\n",
    "    handle_color = handle.get_color()\n",
    "    if not isinstance(handle_color, str):\n",
    "        # Collection handles return an array of RGBA rows; flatten the single row.\n",
    "        handle_color = np.asarray(handle_color).flatten()\n",
    "    text.set_color(handle_color)\n",
    "    handle.set_visible(False)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig.savefig(figures / f\"{config_name}-explained_variance.pdf\", bbox_inches=\"tight\", pad_inches=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Neural activity PCs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "quiescence = \"same\"\n",
    "# quiescence = \"scaled\"\n",
    "lambda_v = 1 * 0.7\n",
    "tau_a = 100  # * 2\n",
    "b_a = 0.1 * 0  # 3\n",
    "# One of:\n",
    "#   - \"scaled\" (quiescent_noise = np.sqrt(2) * active_noise)\n",
    "#   - \"same\" (quiescent_noise = active_noise)\n",
    "#   - \"absolute\" (quiescent_noise = np.sqrt(2 * noise) where noise is defined earlier e.g. 0.0001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 15\n",
    "utils.set_random_seeds(seed)\n",
    "\n",
    "# model = torch.load(seed_dirs[seed] / f\"model_{epoch}.pt\", map_location=\"cpu\")\n",
    "model = torch.load(\n",
    "    seed_dirs[seeds.index(seed)] / f\"model_{epoch}.pt\", map_location=\"cpu\"\n",
    ")\n",
    "model.set_device(\"cpu\")\n",
    "model.eval()\n",
    "\n",
    "task = model.task\n",
    "task.batch_size = 200\n",
    "\n",
    "test_data = task.get_test_batch()\n",
    "t_quiescent = 1000\n",
    "quiescent_inputs = torch.zeros(task.batch_size, t_quiescent, 2)\n",
    "\n",
    "# h_active = model(test_data[\"data\"], test_data[\"init_state\"])\n",
    "# h_active = momentum(model, lambda_v, tau_a, b_a, test_data[\"data\"], test_data[\"init_state\"])\n",
    "h_active = momentum(model, 1, 1, 0, test_data[\"data\"], test_data[\"init_state\"])\n",
    "\n",
    "if quiescence == \"scaled\":\n",
    "    model.sigma_rec *= np.sqrt(2)\n",
    "elif quiescence == \"absolute\":\n",
    "    model.sigma_rec = np.sqrt(2 * noise)\n",
    "elif quiescence == \"same\":\n",
    "    pass\n",
    "\n",
    "# h_quiescent = model(quiescent_inputs, test_data[\"init_state\"])\n",
    "h_quiescent = momentum(\n",
    "    model, lambda_v, tau_a, b_a, quiescent_inputs, test_data[\"init_state\"]\n",
    ")\n",
    "\n",
    "h_active_ = h_active[0].cpu().detach().numpy().reshape(-1, model.n_rec)\n",
    "h_quiescent_ = h_quiescent[0].cpu().detach().numpy().reshape(-1, model.n_rec)\n",
    "\n",
    "pca = PCA()\n",
    "h_active_pca = pca.fit_transform(h_active_)\n",
    "h_quiescent_pca = pca.transform(h_quiescent_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if True:\n",
    "    active_xy = task.place_cells.decode_pos(h_active[1])\n",
    "    quiescent_xy = task.place_cells.decode_pos(h_quiescent[1])\n",
    "else:\n",
    "    normalize = lambda x: x / x.abs().sum(-1)[..., None]\n",
    "    active_xy = normalize(h_active[1] ** 3).detach() @ task.place_cells.us\n",
    "    quiescent_xy = normalize(h_quiescent[1] ** 3).detach() @ task.place_cells.us\n",
    "active_d = (active_xy**2).sum(axis=-1) ** 0.5\n",
    "quiescent_d = (quiescent_xy**2).sum(axis=-1) ** 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(7.5, 1.5))\n",
    "ax.set_box_aspect(1)\n",
    "\n",
    "skip_traj = 2\n",
    "shift_time = 7\n",
    "skip_time = 20\n",
    "t_task = task.sequence_length\n",
    "for i in range(0, task.batch_size, skip_traj):\n",
    "    plt.scatter(\n",
    "        h_active_pca[i * t_task + shift_time : (i + 1) * t_task : skip_time, 0],\n",
    "        h_active_pca[i * t_task + shift_time : (i + 1) * t_task : skip_time, 1],\n",
    "        c=lcmap_active.to_rgba(active_d[i, shift_time::skip_time]),\n",
    "    )\n",
    "\n",
    "ax.grid(visible=False)\n",
    "ax.set_xlabel(\"PC-1\")\n",
    "ax.set_ylabel(\"PC-2\")\n",
    "ax.spines[[\"right\", \"top\"]].set_visible(False)\n",
    "ax.tick_params(axis=\"both\", which=\"major\", labelsize=8)\n",
    "ax.tick_params(axis=\"both\", which=\"minor\", labelsize=8)\n",
    "ax.set_xticks([-1, 0, 1])\n",
    "ax.set_yticks([-1, 0, 1])\n",
    "ax.set_xlim(-1.55, 1.55)\n",
    "ax.set_ylim(-1.55, 1.55)\n",
    "\n",
    "cb = plt.colorbar(mappable=lcmap_active, ax=ax)\n",
    "cb.ax.tick_params(labelsize=8)\n",
    "cb.set_label(label=\"distance from center in\\n(x, y)-space\", labelpad=10, fontsize=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig.savefig(figures / f\"{config_name}_{seed}-pca_active.pdf\", bbox_inches=\"tight\", pad_inches=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(7.5, 1.5))\n",
    "ax.set_box_aspect(1)\n",
    "\n",
    "skip_traj = 2\n",
    "shift_time = 7\n",
    "skip_time = 20\n",
    "for i in range(0, task.batch_size, skip_traj):\n",
    "    plt.scatter(\n",
    "        h_quiescent_pca[\n",
    "            i * t_quiescent + shift_time : (i + 1) * t_quiescent : skip_time, 0\n",
    "        ],\n",
    "        h_quiescent_pca[\n",
    "            i * t_quiescent + shift_time : (i + 1) * t_quiescent : skip_time, 1\n",
    "        ],\n",
    "        c=lcmap_quiescent.to_rgba(quiescent_d[i, shift_time::skip_time]),\n",
    "    )\n",
    "\n",
    "ax.grid(visible=False)\n",
    "ax.set_xlabel(\"PC-1\")\n",
    "ax.set_ylabel(\"PC-2\")\n",
    "ax.spines[[\"right\", \"top\"]].set_visible(False)\n",
    "ax.tick_params(axis=\"both\", which=\"major\", labelsize=8)\n",
    "ax.tick_params(axis=\"both\", which=\"minor\", labelsize=8)\n",
    "ax.set_xticks([-1, 0, 1])\n",
    "ax.set_yticks([-1, 0, 1])\n",
    "ax.set_xlim(-1.55, 1.55)\n",
    "ax.set_ylim(-1.55, 1.55)\n",
    "\n",
    "cb = plt.colorbar(mappable=lcmap_quiescent, ax=ax)\n",
    "cb.ax.tick_params(labelsize=8)\n",
    "cb.ax.yaxis.set_ticks_position(\"left\")\n",
    "cb.set_label(label=\"distance from center in\\n(x, y)-space\", labelpad=10, fontsize=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig.savefig(\n",
    "#    figures / f\"{config_name}_{seed}-pca_quiescent_{quiescence}.pdf\", bbox_inches=\"tight\", pad_inches=0\n",
    "# )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Decoded outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig, axes = plt.subplots(nrows=1, ncols=2, layout=\"compressed\", figsize=(3, 1.5))\n",
    "fig, axes = plt.subplots(nrows=1, ncols=2, layout=\"constrained\", figsize=(3, 1.5))\n",
    "\n",
    "ax = axes[0]\n",
    "ax.set_box_aspect(1)\n",
    "traj_skip = 8\n",
    "\n",
    "for i in range(0, task.batch_size, traj_skip):\n",
    "    ax.plot(active_xy[i, :, 0], active_xy[i, :, 1], c=active, alpha=0.5)\n",
    "    ax.scatter(\n",
    "        active_xy[i, 0, 0], active_xy[i, 0, 1], c=\"k\", label=\"start\", s=10, zorder=3\n",
    "    )\n",
    "    ax.scatter(\n",
    "        active_xy[i, -1, 0],\n",
    "        active_xy[i, -1, 1],\n",
    "        c=\"k\",\n",
    "        label=\"end\",\n",
    "        marker=\"^\",\n",
    "        s=10,\n",
    "        zorder=3,\n",
    "    )\n",
    "\n",
    "ax.grid(visible=False)\n",
    "ax.set_xlabel(\"x\")\n",
    "ax.set_ylabel(\"y\")\n",
    "ax.tick_params(axis=\"both\", which=\"major\", labelsize=8)\n",
    "ax.tick_params(axis=\"both\", which=\"minor\", labelsize=8)\n",
    "ax.set_xlim(-1.15, 1.15)\n",
    "ax.set_ylim(-1.15, 1.15)\n",
    "ax.set_xticks([-1, 0, 1])\n",
    "ax.set_yticks([-1, 0, 1])\n",
    "\n",
    "ax = axes[1]\n",
    "ax.set_box_aspect(1)\n",
    "traj_skip = 4\n",
    "\n",
    "for i in range(0, task.batch_size, traj_skip):\n",
    "    ax.plot(quiescent_xy[i, :, 0], quiescent_xy[i, :, 1], c=quiescent, alpha=0.5)\n",
    "    ax.scatter(\n",
    "        quiescent_xy[i, 0, 0],\n",
    "        quiescent_xy[i, 0, 1],\n",
    "        c=\"k\",\n",
    "        label=\"start\",\n",
    "        s=10,\n",
    "        zorder=3,\n",
    "    )\n",
    "    ax.scatter(\n",
    "        quiescent_xy[i, -1, 0],\n",
    "        quiescent_xy[i, -1, 1],\n",
    "        c=\"k\",\n",
    "        label=\"end\",\n",
    "        marker=\"^\",\n",
    "        s=10,\n",
    "        zorder=3,\n",
    "    )\n",
    "\n",
    "ax.grid(visible=False)\n",
    "ax.set_xlabel(\"x\")\n",
    "ax.set_ylabel(\"y\")\n",
    "ax.tick_params(axis=\"both\", which=\"major\", labelsize=8)\n",
    "ax.tick_params(axis=\"both\", which=\"minor\", labelsize=8)\n",
    "ax.set_xlim(-1.15, 1.15)\n",
    "ax.set_ylim(-1.15, 1.15)\n",
    "ax.set_xticks([-1, 0, 1])\n",
    "ax.set_yticks([-1, 0, 1])\n",
    "\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "by_label = dict(zip(labels, handles))\n",
    "fig.legend(\n",
    "    by_label.values(),\n",
    "    by_label.keys(),\n",
    "    frameon=False,\n",
    "    bbox_to_anchor=(0.49, -0.07),\n",
    "    loc=\"lower center\",\n",
    "    borderaxespad=0,\n",
    "    ncol=2,\n",
    "    handletextpad=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for seq in test_data[\"target_pos\"]:\n",
    "    plt.plot(seq[:, 0], seq[:, 1], c=\"C0\", alpha=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_place_cells(acts):\n",
    "    plt.figure(figsize=(3, 3))\n",
    "    plt.scatter(task.place_cells.us[:, 0], task.place_cells.us[:, 1], c=acts)\n",
    "    plt.axis(\"off\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_place_cells(test_data[\"targets\"][0, 60])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_place_cells(h_active[1][0, 60].detach())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "foo = task.place_cells.decode_pos(test_data[\"targets\"])  # , k=3)\n",
    "for seq in foo:\n",
    "    plt.plot(seq[:, 0], seq[:, 1], c=\"C0\", alpha=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "normalize = lambda x: x / x.abs().sum(-1)[..., None]\n",
    "foo2 = normalize(test_data[\"targets\"] ** 3) @ task.place_cells.us\n",
    "for seq in foo2:\n",
    "    plt.plot(seq[:, 0], seq[:, 1], c=\"C0\", alpha=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in np.arange(10):\n",
    "    plt.plot(\n",
    "        test_data[\"target_pos\"][i, :, 0], test_data[\"target_pos\"][i, :, 1], c=f\"C{i}\"\n",
    "    )\n",
    "    # plt.plot(foo[i,:,0], foo[i,:,1], c=f'C{i}', linestyle='--')\n",
    "    plt.plot(foo2[i, :, 0], foo2[i, :, 1], c=f\"C{i}\", linestyle=\":\")\n",
    "\n",
    "err1 = (foo - test_data[\"target_pos\"]).square().sum(-1).sqrt().mean()\n",
    "err2 = (foo2 - test_data[\"target_pos\"]).square().sum(-1).sqrt().mean()\n",
    "print(err1, err2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig.savefig(figures / f\"{config_name}_{seed}-behavior-q_{quiescence}.pdf\", bbox_inches=\"tight\", pad_inches=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Density plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(3, 3))\n",
    "\n",
    "# Active\n",
    "X = active_xy.reshape(-1, 2).cpu().detach().numpy()\n",
    "x, y = X[:, 0], X[:, 1]\n",
    "xmin, xmax = -1.25, 1.25\n",
    "ymin, ymax = -1.25, 1.25\n",
    "xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]\n",
    "\n",
    "positions = np.vstack([xx.ravel(), yy.ravel()])\n",
    "values = np.vstack([x, y])\n",
    "kernel_active = stats.gaussian_kde(values)\n",
    "f = np.reshape(kernel_active(positions).T, xx.shape)\n",
    "\n",
    "ax = axes[0]\n",
    "ax.set_xlim(xmin, xmax)\n",
    "ax.set_ylim(ymin, ymax)\n",
    "cfset = ax.contourf(xx, yy, f, cmap=\"coolwarm\", levels=np.linspace(0.0, 1.2, 9))\n",
    "ax.imshow(np.rot90(f), cmap=\"coolwarm\", extent=[xmin, xmax, ymin, ymax])\n",
    "ax.set_xlabel(\"x\")\n",
    "ax.set_ylabel(\"y\")\n",
    "ax.tick_params(axis=\"both\", which=\"major\", labelsize=8)\n",
    "ax.tick_params(axis=\"both\", which=\"minor\", labelsize=8)\n",
    "ax.set_xticks([-1, 0, 1])\n",
    "ax.set_yticks([-1, 0, 1])\n",
    "\n",
    "# Quiescent\n",
    "traj_skip = 5\n",
    "X = quiescent_xy[::traj_skip, :, :].reshape(-1, 2).cpu().detach().numpy()\n",
    "x, y = X[:, 0], X[:, 1]\n",
    "xmin, xmax = -1.25, 1.25\n",
    "ymin, ymax = -1.25, 1.25\n",
    "xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]\n",
    "\n",
    "positions = np.vstack([xx.ravel(), yy.ravel()])\n",
    "values = np.vstack([x, y])\n",
    "kernel = stats.gaussian_kde(values)\n",
    "f = np.reshape(kernel(positions).T, xx.shape)\n",
    "\n",
    "ax = axes[1]\n",
    "ax.set_xlim(xmin, xmax)\n",
    "ax.set_ylim(ymin, ymax)\n",
    "cfset = ax.contourf(xx, yy, f, cmap=\"coolwarm\", levels=np.linspace(0.0, 1.2, 9))\n",
    "ax.imshow(np.rot90(f), cmap=\"coolwarm\", extent=[xmin, xmax, ymin, ymax])\n",
    "ax.set_xlabel(\"x\")\n",
    "ax.set_ylabel(\"y\")\n",
    "ax.tick_params(axis=\"both\", which=\"major\", labelsize=8)\n",
    "ax.tick_params(axis=\"both\", which=\"minor\", labelsize=8)\n",
    "ax.set_xticks([-1, 0, 1])\n",
    "ax.set_yticks([-1, 0, 1])\n",
    "\n",
    "cb = fig.colorbar(\n",
    "    cfset,\n",
    "    ax=axes.ravel().tolist(),\n",
    "    orientation=\"horizontal\",\n",
    "    location=\"bottom\",\n",
    "    label=\"density\",\n",
    ")\n",
    "cb.ax.tick_params(labelsize=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig.savefig(\n",
    "#    figures / f\"{config_name}_{seed}-kde_output-q_{quiescence}.pdf\", bbox_inches=\"tight\", pad_inches=0\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(3, 3))\n",
    "\n",
    "# Active\n",
    "X = h_active_pca[:, :2]\n",
    "x, y = X[:, 0], X[:, 1]\n",
    "delta_x = (max(x) - min(x)) / 10\n",
    "delta_y = (max(y) - min(y)) / 10\n",
    "xmin = min(x) - delta_x\n",
    "xmax = max(x) + delta_x\n",
    "ymin = min(y) - delta_y\n",
    "ymax = max(y) + delta_y\n",
    "xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]\n",
    "\n",
    "positions = np.vstack([xx.ravel(), yy.ravel()])\n",
    "values = np.vstack([x, y])\n",
    "kernel = stats.gaussian_kde(values)\n",
    "f = np.reshape(kernel(positions).T, xx.shape)\n",
    "\n",
    "ax = axes[0]\n",
    "ax.set_xlim(xmin, xmax)\n",
    "ax.set_ylim(ymin, ymax)\n",
    "cfset = ax.contourf(xx, yy, f, cmap=\"coolwarm\", levels=np.linspace(0.0, 1.6, 9))\n",
    "ax.imshow(np.rot90(f), cmap=\"coolwarm\", extent=[xmin, xmax, ymin, ymax])\n",
    "ax.set_xlabel(\"PC-1\")\n",
    "ax.set_ylabel(\"PC-2\")\n",
    "ax.tick_params(axis=\"both\", which=\"major\", labelsize=8)\n",
    "ax.tick_params(axis=\"both\", which=\"minor\", labelsize=8)\n",
    "ax.set_xticks([-1, 0, 1])\n",
    "ax.set_yticks([-1, 0, 1])\n",
    "\n",
    "# Quiescent\n",
    "traj_skip = 5\n",
    "X = h_quiescent_pca[::traj_skip, :2]\n",
    "x, y = X[:, 0], X[:, 1]\n",
    "xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]\n",
    "\n",
    "positions = np.vstack([xx.ravel(), yy.ravel()])\n",
    "values = np.vstack([x, y])\n",
    "kernel = stats.gaussian_kde(values)\n",
    "f = np.reshape(kernel(positions).T, xx.shape)\n",
    "\n",
    "ax = axes[1]\n",
    "ax.set_xlim(xmin, xmax)\n",
    "ax.set_ylim(ymin, ymax)\n",
    "cfset = ax.contourf(xx, yy, f, cmap=\"coolwarm\", levels=np.linspace(0.0, 1.6, 9))\n",
    "ax.imshow(np.rot90(f), cmap=\"coolwarm\", extent=[xmin, xmax, ymin, ymax])\n",
    "ax.set_xlabel(\"PC-1\")\n",
    "ax.set_ylabel(\"PC-2\")\n",
    "ax.tick_params(axis=\"both\", which=\"major\", labelsize=8)\n",
    "ax.tick_params(axis=\"both\", which=\"minor\", labelsize=8)\n",
    "ax.set_xticks([-1, 0, 1])\n",
    "ax.set_yticks([-1, 0, 1])\n",
    "\n",
    "cb = fig.colorbar(\n",
    "    cfset,\n",
    "    ax=axes.ravel().tolist(),\n",
    "    orientation=\"horizontal\",\n",
    "    location=\"bottom\",\n",
    "    label=\"density\",\n",
    ")\n",
    "cb.ax.tick_params(labelsize=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig.savefig(figures / f\"{config_name}_{seed}-kde_pca-q_{quiescence}.pdf\", bbox_inches=\"tight\", pad_inches=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = test_data[\"target_pos\"].reshape(-1, 2).cpu().detach().numpy()\n",
    "x, y = X[:, 0], X[:, 1]\n",
    "xmin, xmax = -1.25, 1.25\n",
    "ymin, ymax = -1.25, 1.25\n",
    "xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]\n",
    "\n",
    "positions = np.vstack([xx.ravel(), yy.ravel()])\n",
    "values = np.vstack([x, y])\n",
    "kernel_test = stats.gaussian_kde(values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep the velocity decay (l_v) and adaptation gain (b_a) of the quiescent\n",
    "# dynamics and record the KL divergence reported by `utils.kl_divergence` between\n",
    "# the KDE of the target positions (`kernel_test`) and the KDE of decoded positions.\n",
    "kls = []\n",
    "# distances = []\n",
    "for l_v in [1, 0.9, 0.8, 0.7]:\n",
    "    for b_a in [0, 0.5, 1]:\n",
    "        h_q = momentum(\n",
    "            model, l_v, tau_a, b_a, quiescent_inputs, test_data[\"init_state\"]\n",
    "        )\n",
    "        # Decode from THIS run's outputs (h_q[1]). Decoding from the earlier\n",
    "        # `h_quiescent` made every entry of `kls` independent of l_v and b_a.\n",
    "        # quiescent_xy = task.place_cells.decode_pos(h_q[1])\n",
    "        quiescent_xy = normalize(h_q[1] ** 3).detach() @ task.place_cells.us\n",
    "        # avg_dist = quiescent_xy.diff(dim=1).square().sum(-1).sum(-1).mean()\n",
    "        # distances.append(avg_dist)\n",
    "\n",
    "        # Quiescent: fit the KDE on every `traj_skip`-th trajectory.\n",
    "        traj_skip = 5\n",
    "        X = quiescent_xy[::traj_skip, :, :].reshape(-1, 2).cpu().detach().numpy()\n",
    "        x, y = X[:, 0], X[:, 1]\n",
    "        xmin, xmax = -1.25, 1.25\n",
    "        ymin, ymax = -1.25, 1.25\n",
    "        xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]\n",
    "\n",
    "        positions = np.vstack([xx.ravel(), yy.ravel()])\n",
    "        values = np.vstack([x, y])\n",
    "        kernel = stats.gaussian_kde(values)\n",
    "        kls.append(utils.kl_divergence(kernel_test, kernel, n=5_000))  # kernel_active\n",
    "\n",
    "        print(f\"{l_v=},{tau_a=}, {b_a=}    {kls[-1] :.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# KL divergence per (l_v, b_a) setting, in the order the sweep appended them.\n",
    "# Dedicated names keep the global fig/ax of the other figure cells untouched.\n",
    "fig_kls, ax_kls = plt.subplots()\n",
    "ax_kls.plot(kls)\n",
    "ax_kls.set_xlabel(\"sweep index\")\n",
    "ax_kls.set_ylabel(\"KL divergence\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): identical to the preceding plot cell; candidate for removal.\n",
    "plt.plot(kls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): identical to the preceding plot cell; candidate for removal.\n",
    "plt.plot(kls)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Average output variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load one variance JSON per seed. read_text() opens and closes each file,\n",
    "# whereas json.load(open(...)) left the handles to the garbage collector.\n",
    "# NOTE(review): assumes `results` is a pathlib.Path (it is joined with `/`).\n",
    "variance_files = [\n",
    "    json.loads(\n",
    "        (results / \"spatial_navigation\" / f\"variance_{noise}_{seed}.json\").read_text()\n",
    "    )\n",
    "    for seed in range(num_seeds)\n",
    "]\n",
    "# One row per seed, one column per entry of that seed's JSON object.\n",
    "variance = np.array([list(f.values()) for f in variance_files])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Box plot of `variance`, one box per column, filled with the quiescent colour.\n",
    "fig, ax = plt.subplots(figsize=(3, 1.5))\n",
    "c = \"black\"  # shared colour for outlines, whiskers, fliers and medians\n",
    "variance_tick_labels = [\n",
    "    \"noiseless\\nunbiased\",\n",
    "    \"noiseless\\nbiased\",\n",
    "    \"noisy\\nunbiased\",\n",
    "    \"noisy\\nbiased\",\n",
    "]\n",
    "ax.boxplot(\n",
    "    variance,\n",
    "    labels=variance_tick_labels,\n",
    "    patch_artist=True,\n",
    "    boxprops={\"facecolor\": quiescent, \"color\": c},\n",
    "    capprops={\"color\": c, \"linewidth\": 0},  # zero width hides the caps\n",
    "    whiskerprops={\"color\": c},\n",
    "    flierprops={\"color\": c, \"markeredgecolor\": c},\n",
    "    medianprops={\"color\": c},\n",
    ")\n",
    "ax.tick_params(axis=\"both\", which=\"both\", labelsize=8)\n",
    "ax.tick_params(axis=\"x\", bottom=False)\n",
    "ax.spines[[\"right\", \"top\"]].set_visible(False)\n",
    "ax.set_ylabel(\"avg. variance\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the current `fig` as a tightly cropped PDF under `figures`.\n",
    "fig.savefig(\n",
    "    figures / f\"output_variance_{noise}.pdf\",\n",
    "    bbox_inches=\"tight\",\n",
    "    pad_inches=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### KL divergence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load one KL JSON per seed. read_text() opens and closes each file,\n",
    "# whereas json.load(open(...)) left the handles to the garbage collector.\n",
    "# NOTE(review): assumes `results` is a pathlib.Path (it is joined with `/`).\n",
    "# NOTE(review): this iterates `seeds`, while the cells below index the result\n",
    "# with range(num_seeds) -- confirm the two agree.\n",
    "kl_files = [\n",
    "    json.loads((results / \"spatial_navigation\" / f\"kl_{noise}_{seed}.json\").read_text())\n",
    "    for seed in seeds\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def outer_order(seed):\n",
    "    \"\"\"Return the six top-level keys for ``seed`` in the order they are plotted.\"\"\"\n",
    "    unbiased_keys = [f\"noisy_unbiased_0.0001_{seed}\", f\"no-noise_unbiased_{seed}\"]\n",
    "    biased_keys = [f\"noisy_biased_0.0001_{seed}\", f\"no-noise_biased_{seed}\"]\n",
    "    return [*unbiased_keys, \"uniform\", *biased_keys, \"random\"]\n",
    "\n",
    "\n",
    "# Rebuild each per-seed dict so its keys follow outer_order(seed).\n",
    "kl_files = [\n",
    "    dict((key, kl_files[seed][key]) for key in outer_order(seed))\n",
    "    for seed in range(num_seeds)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inner_order(seed):\n",
    "    \"\"\"Return the four second-level keys for ``seed`` in the order they are plotted.\"\"\"\n",
    "    return [\n",
    "        f\"noisy_unbiased_0.0001_{seed}\",\n",
    "        f\"no-noise_unbiased_{seed}\",\n",
    "        f\"noisy_biased_0.0001_{seed}\",\n",
    "        f\"no-noise_biased_{seed}\",\n",
    "    ]\n",
    "\n",
    "\n",
    "# Reorder the second-level keys of every per-seed dict. Iterating kl_files\n",
    "# directly covers however many seeds were loaded instead of a hard-coded 5,\n",
    "# and the comprehension keeps `seed` from leaking into the notebook globals.\n",
    "kl_files_n = [\n",
    "    {\n",
    "        outer_key: {inner_key: inner[inner_key] for inner_key in inner_order(seed)}\n",
    "        for outer_key, inner in per_seed.items()\n",
    "    }\n",
    "    for seed, per_seed in enumerate(kl_files)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nested lists indexed [seed][outer key][inner key], then stacked into an array.\n",
    "kl_nested = [\n",
    "    [list(inner.values()) for inner in kl_files_n[seed].values()]\n",
    "    for seed in range(5)\n",
    "]\n",
    "kl = np.array(kl_nested)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average over the first axis of `kl` (one entry per seed).\n",
    "kl_mean = np.mean(kl, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heat map of the seed-averaged KL matrix, transposed so that the first axis\n",
    "# of kl_mean runs along x and the second along y.\n",
    "fig, ax = plt.subplots(figsize=(3, 1.5), dpi=600)\n",
    "\n",
    "im = ax.matshow(kl_mean.T, cmap=\"coolwarm\", interpolation=\"nearest\")\n",
    "\n",
    "kl_xticklabels = [\n",
    "    r\"$\\mathregular{U}^{\\sigma}$\",\n",
    "    \"U\",\n",
    "    r\"$\\mathcal{U}$\",\n",
    "    r\"$\\mathregular{B}^{\\sigma}$\",\n",
    "    \"B\",\n",
    "    r\"$\\mathregular{R}^{\\sigma}$\",\n",
    "]\n",
    "kl_yticklabels = [\n",
    "    r\"$\\mathregular{U}^{\\sigma}$\",\n",
    "    \"U\",\n",
    "    r\"$\\mathregular{B}^{\\sigma}$\",\n",
    "    \"B\",\n",
    "]\n",
    "ax.set_xticks(range(6), kl_xticklabels)\n",
    "ax.set_yticks(range(4), kl_yticklabels)\n",
    "\n",
    "ax.tick_params(axis=\"both\", which=\"both\", labelsize=8)\n",
    "ax.tick_params(axis=\"y\", left=False)\n",
    "# Tick labels go above the matrix; the tick marks themselves are hidden.\n",
    "ax.tick_params(axis=\"x\", labeltop=True, labelbottom=False, top=False, bottom=False)\n",
    "ax.set_xlabel(\"quiescent\")\n",
    "ax.set_ylabel(\"active\")\n",
    "\n",
    "cb = fig.colorbar(im, location=\"right\")\n",
    "cb.set_label(\"KL divergence\")\n",
    "cb.ax.tick_params(axis=\"both\", which=\"major\", labelsize=8)\n",
    "# Four evenly spaced colour-bar ticks spanning the data range.\n",
    "cb.ax.set_yticks(np.linspace(kl_mean.min(), kl_mean.max(), 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the current `fig` as a tightly cropped PDF under `figures`.\n",
    "fig.savefig(\n",
    "    figures / f\"output_kl_{noise}.pdf\",\n",
    "    bbox_inches=\"tight\",\n",
    "    pad_inches=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Average point-wise distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load one distance JSON per seed. read_text() opens and closes each file,\n",
    "# whereas json.load(open(...)) left the handles to the garbage collector.\n",
    "# NOTE(review): assumes `results` is a pathlib.Path (it is joined with `/`).\n",
    "distance_files = [\n",
    "    json.loads(\n",
    "        (results / \"spatial_navigation\" / f\"distance_{noise}_{seed}.json\").read_text()\n",
    "    )\n",
    "    for seed in range(5)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toggle between the \"unbiased\" and \"biased\" entries of the distance files.\n",
    "unbiased = True\n",
    "un = \"un\" if unbiased else \"\"\n",
    "\n",
    "# (top-level key, per-seed key template) for each box, in plotting order.\n",
    "noisy_key = \"noisy_{un}biased_0.0001_{seed}\"\n",
    "no_noise_key = \"no-noise_{un}biased_{seed}\"\n",
    "distance_groups = [\n",
    "    (\"active\", noisy_key),\n",
    "    (\"quiescent\", noisy_key),\n",
    "    (\"active\", no_noise_key),\n",
    "    (\"quiescent\", no_noise_key),\n",
    "    (\"quiescent_noisy\", no_noise_key),\n",
    "]\n",
    "\n",
    "# Pool the five seeds of every group into one flat array.\n",
    "distance = [\n",
    "    np.array(\n",
    "        [\n",
    "            distance_files[seed][phase][key.format(un=un, seed=seed)]\n",
    "            for seed in range(5)\n",
    "        ]\n",
    "    ).flatten()\n",
    "    for phase, key in distance_groups\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Box plot of `distance`, one gray box per group.\n",
    "fig, ax = plt.subplots(figsize=(2.7, 1.5))\n",
    "c = \"black\"  # shared colour for outlines, whiskers, fliers and medians\n",
    "distance_tick_labels = [\n",
    "    \"NT\\nactive\",\n",
    "    \"NT\\nquiescent\",\n",
    "    \"DT\\nactive\",\n",
    "    \"DT\\nquiescent\",\n",
    "    \"DT\\nnoisy\\nquiescent\",\n",
    "]\n",
    "ax.boxplot(\n",
    "    distance,\n",
    "    labels=distance_tick_labels,\n",
    "    patch_artist=True,\n",
    "    boxprops={\"facecolor\": \"gray\", \"color\": c},\n",
    "    capprops={\"color\": c, \"linewidth\": 0},  # zero width hides the caps\n",
    "    whiskerprops={\"color\": c},\n",
    "    flierprops={\"color\": c, \"markeredgecolor\": c},\n",
    "    medianprops={\"color\": c},\n",
    ")\n",
    "ax.tick_params(axis=\"both\", which=\"both\", labelsize=8)\n",
    "ax.tick_params(axis=\"x\", bottom=False)\n",
    "ax.spines[[\"right\", \"top\"]].set_visible(False)\n",
    "ax.set_ylabel(\"avg. distance\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the current `fig` as a tightly cropped PDF under `figures`.\n",
    "fig.savefig(\n",
    "    figures / f\"output_distance_{noise}.pdf\",\n",
    "    bbox_inches=\"tight\",\n",
    "    pad_inches=0,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
