{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analyzing results for Spatial Navigation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import matplotlib as mpl\n",
    "import matplotlib.font_manager as fm\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import pathlib\n",
    "import pickle\n",
    "import sys\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "\n",
    "from plots import plot_pos, scatter_pos, stemplot, violinplot, add_border_and_ticks\n",
    "from helpers import (\n",
    "    optimal_decoder,\n",
    "    gather_models,\n",
    "    gather_test_data,\n",
    "    momentum,\n",
    "    dict_to_csv,\n",
    "    wd_types_to_csvs,\n",
    ")\n",
    "from metrics import wasserstein, calc_wds\n",
    "\n",
    "os.makedirs(\"pickle\", exist_ok=True)\n",
    "replay_dict_path = \"pickle/replay_unbiased.pkl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "active = \"#2ca7c5\"\n",
    "quiescent = \"#ee3233\"\n",
    "\n",
    "active_colors = [\n",
    "    \"#ffffff\",\n",
    "    \"#d4edf3\",\n",
    "    \"#aadbe7\",\n",
    "    \"#80cadc\",\n",
    "    \"#56b8d0\",\n",
    "    \"#2ca7c5\",\n",
    "    \"#2796b1\",\n",
    "    \"#1e7489\",\n",
    "    \"#165362\",\n",
    "    \"#0d323b\",\n",
    "]\n",
    "lcmap_active = mpl.colors.LinearSegmentedColormap.from_list(\n",
    "    \"lcmap_active\", active_colors\n",
    ")\n",
    "norm = mpl.colors.Normalize(vmin=0, vmax=1.5)\n",
    "lcmap_active = mpl.cm.ScalarMappable(norm=norm, cmap=lcmap_active)\n",
    "lcmap_active.set_array([])\n",
    "\n",
    "quiescent_colors = [\n",
    "    \"#ffffff\",\n",
    "    \"#fbd6d6\",\n",
    "    \"#f8adad\",\n",
    "    \"#f48484\",\n",
    "    \"#f15a5b\",\n",
    "    \"#ee3233\",\n",
    "    \"#d62d2d\",\n",
    "    \"#a62323\",\n",
    "    \"#771919\",\n",
    "    \"#470f0f\",\n",
    "]\n",
    "lcmap_quiescent = mpl.colors.LinearSegmentedColormap.from_list(\n",
    "    \"lcmap_quiescent\", quiescent_colors\n",
    ")\n",
    "norm = mpl.colors.Normalize(vmin=0, vmax=1.5)\n",
    "lcmap_quiescent = mpl.cm.ScalarMappable(norm=norm, cmap=lcmap_quiescent)\n",
    "lcmap_quiescent.set_array([])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load models and data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "unmask_every = 6  # 3\n",
    "epoch = 15000\n",
    "# epoch = 5000 # 2500\n",
    "t_quiescent = 500  # 1000 # 1000 # 250 # 100\n",
    "quiescence = \"same\"\n",
    "lambda_scaling = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = pathlib.Path(\"../results\")\n",
    "noise = 0.0001\n",
    "config_name = f\"noisy_unbiased_unmask_every_{unmask_every}_{noise}\"\n",
    "seed_dirs = [\n",
    "    results / \"spatial_navigation\" / f\n",
    "    for f in os.listdir(results / \"spatial_navigation\")\n",
    "    if config_name in f\n",
    "]\n",
    "seed_dirs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# seed_dirs = [seed_dirs[1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load models and their test data\n",
    "models = gather_models(seed_dirs, epoch)\n",
    "test_obs, test_pos, test_inits = gather_test_data(models)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decoding accuracy: fixed vs optimized decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fixed_pos = models[0].task.place_cells.decode_pos(test_obs[0])\n",
    "optimal_pos = optimal_decoder(models[0], test_obs[0], verbose=True)\n",
    "\n",
    "fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(10, 5), sharex=True, sharey=True)\n",
    "for i, (pos, title) in enumerate(\n",
    "    zip([test_pos[0], fixed_pos, optimal_pos], [\"True\", \"Fixed\", \"Optimized\"])\n",
    "):\n",
    "    plot_pos(axs[0, i], pos, title=title)\n",
    "    scatter_pos(axs[1, i], pos, title=title)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"position RMSE:\")\n",
    "print((test_pos[0] - fixed_pos).square().mean().sqrt().item())\n",
    "print((test_pos[0] - optimal_pos).square().mean().sqrt().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"position Wasserstein distances:\")\n",
    "print(wasserstein(test_pos[0].reshape(-1, 2), fixed_pos.reshape(-1, 2)))\n",
    "print(wasserstein(test_pos[0].reshape(-1, 2), optimal_pos.reshape(-1, 2)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulating the model with various parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tau_a = 100\n",
    "lambda_v_vals = [1, 0.9, 0.8, 0.7]\n",
    "b_a_vals = [0, 0.1, 0.2, 0.3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "replay = {}\n",
    "quiescent_inputs = torch.zeros(models[0].task.batch_size, t_quiescent, 2)\n",
    "for lambda_v in lambda_v_vals:\n",
    "    for b_a in b_a_vals:\n",
    "        key = str({\"lambda_v\": lambda_v, \"tau_a\": tau_a, \"b_a\": b_a})\n",
    "        print(key)\n",
    "        replay[key] = []\n",
    "        for model, init in zip(models, test_inits):\n",
    "            pc_act = momentum(\n",
    "                model,\n",
    "                lambda_v=lambda_v,\n",
    "                tau_a=tau_a,\n",
    "                b_a=b_a,\n",
    "                quiescence=quiescence,\n",
    "                lambda_scaling=lambda_scaling,\n",
    "                x=quiescent_inputs,\n",
    "                init_state=init,\n",
    "            )[1]\n",
    "            replay[key].append(\n",
    "                optimal_decoder(\n",
    "                    model,\n",
    "                    pc_act,\n",
    "                    lr=(2e-1 if quiescence == \"scaled\" else 1e-1),\n",
    "                    verbose=True,\n",
    "                    iters=51,\n",
    "                )\n",
    "            )\n",
    "        replay[key] = torch.stack(replay[key])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save replays\n",
    "with open(replay_dict_path, \"wb\") as file:\n",
    "    pickle.dump(replay, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load replays\n",
    "with open(replay_dict_path, \"rb\") as file:\n",
    "    replay = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "list(replay.values())[0].shape  # seeds, N, T, 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Decoded outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 0\n",
    "fig, axs = plt.subplots(nrows=4, ncols=4, figsize=(15, 15))\n",
    "axs = axs.flatten()\n",
    "for i, (k, v) in enumerate(replay.items()):\n",
    "    # plot_pos(axs[i], v, title = k[:k.index(' ')] + '\\n' + k[k.index(' ')+1:])\n",
    "    plot_pos(\n",
    "        axs[i], v[idx, :, :100], title=k[: k.index(\" \")] + \"\\n\" + k[k.index(\" \") + 1 :]\n",
    "    )\n",
    "fig.suptitle(f\"{unmask_every=}, {epoch=}, t_quiescent=100, {tau_a=}\")\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(3, 3))\n",
    "plot_pos(ax, test_pos[idx], title=\"test (true)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"../figures\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cmap = mpl.colormaps[\"inferno\"]  # 'plasma'\n",
    "T_awake = test_pos.shape[2]  # 100\n",
    "xlim = (-1.12, 1.12)\n",
    "ylim = (-1.12, 1.12)\n",
    "\n",
    "for lv in (lambda_v_vals[0], lambda_v_vals[-1]):\n",
    "    for ba in (b_a_vals[0], b_a_vals[-1]):\n",
    "        fig, ax = plt.subplots(dpi=100)\n",
    "        ax.axis(\"off\")\n",
    "        ax.set(xlim=xlim, ylim=ylim)\n",
    "        add_border_and_ticks(\n",
    "            ax, xlim, ylim, [-1, 0, 1], [-1, 0, 1], extra_padding=(0.015, 0.02)\n",
    "        )  # gray\n",
    "        # x_hat = test_pos\n",
    "        x_hat = replay[str({\"lambda_v\": lv, \"tau_a\": tau_a, \"b_a\": ba})]\n",
    "        x_hat = x_hat.flatten(0, 1)  # combine the seed and sample dimensions\n",
    "        for t in range(T_awake - 1):\n",
    "            # plot every 5th sample\n",
    "            ax.plot(\n",
    "                x_hat[::5, t : t + 2, 0].T,\n",
    "                x_hat[::5, t : t + 2, 1].T,\n",
    "                c=cmap(1 - t / T_awake),\n",
    "                alpha=0.5,\n",
    "                linewidth=2.5,\n",
    "            )  # , zorder=T-t)\n",
    "        fig.savefig(\n",
    "            f\"../figures/unbiased_lambda_v={lv},b_a={ba}.png\",\n",
    "            bbox_inches=\"tight\",\n",
    "            pad_inches=0,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Wasserstein distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"../csv\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wds = calc_wds(replay, test_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dict_to_csv(wds, \"b_a\", path=\"../csv/wds_unbiased.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### velocity and distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# velocity\n",
    "# make sure that the number of timesteps are the same\n",
    "# shape = seeds, N, T, 2\n",
    "t_vel = test_pos.shape[2]\n",
    "replay_vels = {\n",
    "    k: v[:, :, :t_vel].diff(dim=2).square().sum(-1).sqrt() for (k, v) in replay.items()\n",
    "}\n",
    "test_vels = test_pos.diff(dim=2).square().sum(-1).sqrt()\n",
    "\n",
    "replay_dists = {k: v.sum(-1) for k, v in replay_vels.items()}\n",
    "test_dists = test_vels.sum(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dist_avg = {k: v.mean().item() for k, v in replay_dists.items()}\n",
    "dist_med = {k: v.median().item() for k, v in replay_dists.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dist_med"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xlabels = [\n",
    "    f\"$\\lambda_v={eval(k)['lambda_v']}$, \\n $b_a={eval(k)['b_a']}\"\n",
    "    for k in replay.keys()\n",
    "]\n",
    "\n",
    "violinplot({k: v.sum(-1).flatten() for k, v in replay_vels.items()}, xlabels)\n",
    "plt.title(\n",
    "    \"Unbiased rat distance distributions (yellow = medians) (dashed black = median awake)\"\n",
    "    + f\"\\n{unmask_every=}, {epoch=}, {t_quiescent=}, {quiescence=}, {lambda_scaling=}\"\n",
    ")\n",
    "plt.hlines(\n",
    "    test_vels.sum(-1).median(),\n",
    "    1,\n",
    "    len(xlabels),\n",
    "    linestyle=\"--\",\n",
    "    color=\"black\",\n",
    "    alpha=0.5,\n",
    "    zorder=-1,\n",
    ")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xlabels = [\n",
    "    f\"$\\lambda_v={eval(k)['lambda_v']}$, \\n $b_a={eval(k)['b_a']}\"\n",
    "    for k in replay.keys()\n",
    "]\n",
    "\n",
    "violinplot({k: v.flatten() for k, v in replay_vels.items()}, xlabels)\n",
    "plt.title(\n",
    "    \"Unbiased rat velocity distributions (yellow = medians) (dashed black = median awake)\"\n",
    "    + f\"\\n{unmask_every=}, {epoch=}, {t_quiescent=}, {quiescence=}, {lambda_scaling=}\"\n",
    ")\n",
    "plt.hlines(\n",
    "    test_vels.median(),\n",
    "    1,\n",
    "    len(xlabels),\n",
    "    linestyle=\"--\",\n",
    "    color=\"black\",\n",
    "    alpha=0.5,\n",
    "    zorder=-1,\n",
    ")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
