{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analyzing results for Spatial Navigation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import pathlib\n",
    "import pickle\n",
    "import sys\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "\n",
    "from plots import plot_pos, scatter_pos, stemplot, violinplot, add_border_and_ticks\n",
    "from helpers import (\n",
    "    optimal_decoder,\n",
    "    gather_models,\n",
    "    gather_test_data,\n",
    "    momentum,\n",
    "    dict_to_csv,\n",
    "    wd_types_to_csvs,\n",
    "    array_to_heatmap,\n",
    ")\n",
    "from metrics import wasserstein, calc_wds  # , calc_kls\n",
    "\n",
    "os.makedirs(\"pickle\", exist_ok=True)\n",
    "replay_dict_path = \"pickle/replay_biased.pkl\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load models and data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters: which trained runs to load and how quiescence is simulated\n",
    "unmask_every = 3  # part of the results-directory name matched when collecting seed_dirs\n",
    "epoch = 15000  # passed to gather_models\n",
    "# other epoch values left in place by the author: 5000, 2500\n",
    "t_quiescent = 500  # time dimension of the all-zero input fed to momentum(); other values left by the author: 1000, 250, 100\n",
    "quiescence = \"same\"  # forwarded to momentum(); the value \"scaled\" would also switch the decoder lr to 2e-1\n",
    "lambda_scaling = False  # forwarded to momentum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Result directories under ../results/spatial_navigation whose name contains config_name\n",
    "results = pathlib.Path(\"../results\")\n",
    "noise = 0.0001\n",
    "config_name = f\"noisy_biased_unmask_every_{unmask_every}_{noise}\"\n",
    "# os.listdir returns names in arbitrary order; sort them so the order of seed_dirs\n",
    "# (and of everything indexed by seed, such as seed_idx) is the same on every run.\n",
    "# NOTE: this is a substring match, so a longer name that merely contains config_name\n",
    "# (e.g. a noise value with extra trailing digits) is picked up as well.\n",
    "seed_dirs = [\n",
    "    results / \"spatial_navigation\" / f\n",
    "    for f in sorted(os.listdir(results / \"spatial_navigation\"))\n",
    "    if config_name in f\n",
    "]\n",
    "seed_dirs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load models and their test data\n",
    "# Fail early with a readable message: with no matching directories the notebook would\n",
    "# otherwise break later (at the latest on models[0]) with an opaque error.\n",
    "assert len(seed_dirs) > 0, (\n",
    "    f\"no result directories matching {config_name!r} under {results}\"\n",
    ")\n",
    "models = gather_models(seed_dirs, epoch)\n",
    "test_obs, test_pos, test_inits = gather_test_data(models)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulating the model with various parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters forwarded to momentum(): tau_a is held fixed, while every\n",
    "# (lambda_v, b_a) pair drawn from the two lists below is simulated.\n",
    "tau_a = 100\n",
    "lambda_v_vals = [1, 0.8, 0.6, 0.5]\n",
    "b_a_vals = [0, 1, 3, 5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode a replay trajectory for every (lambda_v, b_a) pair and every loaded model\n",
    "quiescent_inputs = torch.zeros(models[0].task.batch_size, t_quiescent, 2)\n",
    "decoder_lr = 2e-1 if quiescence == \"scaled\" else 1e-1\n",
    "\n",
    "replay = {}\n",
    "for lambda_v in lambda_v_vals:\n",
    "    for b_a in b_a_vals:\n",
    "        params = {\"lambda_v\": lambda_v, \"tau_a\": tau_a, \"b_a\": b_a}\n",
    "        key = str(params)  # later cells rebuild this exact string to look results up\n",
    "        print(key)\n",
    "        decoded = []\n",
    "        for model, init in zip(models, test_inits):\n",
    "            sim_out = momentum(\n",
    "                model,\n",
    "                quiescence=quiescence,\n",
    "                lambda_scaling=lambda_scaling,\n",
    "                x=quiescent_inputs,\n",
    "                init_state=init,\n",
    "                **params,\n",
    "            )\n",
    "            pc_act = sim_out[1]  # only the second element of momentum()'s output is decoded\n",
    "            decoded.append(\n",
    "                optimal_decoder(model, pc_act, lr=decoder_lr, verbose=True, iters=51)\n",
    "            )\n",
    "        # one stacked tensor per setting, with the models along dim 0\n",
    "        replay[key] = torch.stack(decoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the decoded replays to replay_dict_path\n",
    "with pathlib.Path(replay_dict_path).open(\"wb\") as replay_file:\n",
    "    pickle.dump(replay, replay_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the replay dict back from replay_dict_path.\n",
    "# NOTE: unpickling can execute arbitrary code; only load pickles written by this notebook.\n",
    "with pathlib.Path(replay_dict_path).open(\"rb\") as replay_file:\n",
    "    replay = pickle.load(replay_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick shape check of the first entry; comments in later cells describe the layout as (seed, N, T, 2)\n",
    "list(replay.values())[0].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Decoded outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_idx = 0\n",
    "t_plot = 100  # number of leading timesteps drawn in each panel (also shown in the title)\n",
    "# One panel per (lambda_v, b_a) pair: size the grid from the sweep lists rather than a\n",
    "# hard-coded 4x4; squeeze=False keeps axs 2-D (so .flatten() works) even for one row/column.\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=len(lambda_v_vals), ncols=len(b_a_vals), figsize=(15, 15), squeeze=False\n",
    ")\n",
    "axs = axs.flatten()\n",
    "for i, (k, v) in enumerate(replay.items()):\n",
    "    # break the key at its first space so the panel title spans two lines\n",
    "    head, sep, tail = k.partition(\" \")\n",
    "    plot_pos(axs[i], v[seed_idx, :, :t_plot], title=head + \"\\n\" + tail)\n",
    "fig.suptitle(\n",
    "    f\"{unmask_every=}, {epoch=}, t_quiescent={t_plot}, {tau_a=}, seed={seed_idx}\"\n",
    ")\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test positions for the seed selected by seed_idx, on a single small axis\n",
    "fig, ax = plt.subplots(figsize=(3, 3))\n",
    "plot_pos(ax, test_pos[seed_idx], title=\"test (true)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory for the figures saved by the next cell\n",
    "pathlib.Path(\"../figures\").mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trajectory figures for the first/last lambda_v combined with the first/last b_a\n",
    "cmap = matplotlib.colormaps[\"inferno\"]\n",
    "T_awake = test_pos.shape[2]  # only this many leading replay steps are drawn\n",
    "xlim = ylim = (-0.9, 0.9)\n",
    "corner_settings = [\n",
    "    (lv, ba)\n",
    "    for lv in (lambda_v_vals[0], lambda_v_vals[-1])\n",
    "    for ba in (b_a_vals[0], b_a_vals[-1])\n",
    "]\n",
    "\n",
    "for lv, ba in corner_settings:\n",
    "    fig, ax = plt.subplots(dpi=100)\n",
    "    ax.axis(\"off\")\n",
    "    ax.set(xlim=xlim, ylim=ylim)\n",
    "    add_border_and_ticks(\n",
    "        ax, xlim, ylim, [-0.8, 0, 0.8], [-0.8, 0, 0.8], extra_padding=(0.015, 0.02)\n",
    "    )\n",
    "\n",
    "    # combine the seed and sample dimensions, then keep every 5th trajectory\n",
    "    x_hat = replay[str({\"lambda_v\": lv, \"tau_a\": tau_a, \"b_a\": ba})].flatten(0, 1)\n",
    "    subsampled = x_hat[::5]\n",
    "    for t in range(T_awake - 1):\n",
    "        # one two-point segment per step, so the colour can encode time\n",
    "        segment = subsampled[:, t : t + 2]\n",
    "        ax.plot(\n",
    "            segment[..., 0].T,\n",
    "            segment[..., 1].T,\n",
    "            c=cmap(1 - t / T_awake),\n",
    "            alpha=0.4,\n",
    "            linewidth=2.5,\n",
    "        )\n",
    "    fig.savefig(\n",
    "        f\"../figures/biased_lambda_v={lv},b_a={ba}.png\",\n",
    "        bbox_inches=\"tight\",\n",
    "        pad_inches=0,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Wasserstein distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory for the CSV files written below\n",
    "pathlib.Path(\"../csv\").mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the replays against the test positions with metrics.calc_wds, using its default\n",
    "# options (presumably Wasserstein distances, per the section title; confirm in metrics.py)\n",
    "wds = calc_wds(replay, test_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export wds to ../csv/wds_biased.csv; \"b_a\" is dict_to_csv's second argument\n",
    "# (presumably the parameter laid out across columns; confirm in helpers.py)\n",
    "dict_to_csv(wds, \"b_a\", path=\"../csv/wds_biased.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "# calculate various variations of wasserstein distance\n",
    "wds_100, wds_250, wds_500 = {}, {}, {}\n",
    "for key in ['all', 'seed']:\n",
    "    seedwise = key == 'seed'\n",
    "    wds_100[key] = calc_wds(replay, test_pos, t_lim=100, timewise=False, seedwise=seedwise)\n",
    "    wds_250[key] = calc_wds(replay, test_pos, t_lim=250, timewise=False, seedwise=seedwise)\n",
    "    wds_500[key] = calc_wds(replay, test_pos, t_lim=500, timewise=False, seedwise=seedwise)\n",
    "wds_100['time'] = calc_wds(replay, test_pos, t_lim=100, timewise=True, seedwise=False)\n",
    "wds_100['timeseed'] = calc_wds(replay, test_pos, t_lim=100, timewise=True, seedwise=True)\n",
    "\n",
    "# save all the means of every wasserstein distance variation\n",
    "os.makedirs('../csv', exist_ok=True)\n",
    "wds_100_dfs = wd_types_to_csvs(wds_100, 100, 'biased', 'b_a', '../csv')\n",
    "wds_250_dfs = wd_types_to_csvs(wds_250, 250, 'biased', 'b_a', '../csv')\n",
    "wds_500_dfs = wd_types_to_csvs(wds_500, 500, 'biased', 'b_a', '../csv')\n",
    "\n",
    "\n",
    "# plot wasserstein distances\n",
    "colors = ['blue', 'purple', 'red', 'orange']\n",
    "linestyles = ['-', '--']\n",
    "fig, axs = plt.subplots(2,2, figsize=(15,5))\n",
    "axs = axs.flatten()\n",
    "\n",
    "for i, wd_df_set, T in zip(range(4), [wds_100_dfs]*2 + [wds_250_dfs, wds_500_dfs],\\\n",
    "                           [100]*2 + [250,500]):\n",
    "    key_list = ['all', 'seed'] if i != 1 else ['time', 'timeseed']\n",
    "    for j, key in enumerate(key_list):\n",
    "        for b_idx, b_a_val in enumerate(b_a_vals):\n",
    "            axs[i].plot(wd_df_set[key]['lambda_v'], wd_df_set[key][f'b_a={b_a_val}'],\n",
    "                        c=colors[b_idx], linestyle=linestyles[j])\n",
    "    axs[i].xaxis.set_inverted(True)\n",
    "    axs[i].set(title=f'{T} timesteps, {key_list}')\n",
    "fig.suptitle('Wasserstein distances')\n",
    "fig.tight_layout()\n",
    "\"\"\";"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "# Kernel densities, KL divergences\n",
    "kls_100, kde_100 = calc_kls(replay, test_pos, t_lim=100)\n",
    "kls_250, kde_250 = calc_kls(replay, test_pos, t_lim=250)\n",
    "\n",
    "stemplot({k:np.mean(v) for (k,v) in kls_100.items()}, xlabels)\n",
    "plt.title(\n",
    "f\"\"\"Unbiased RatInABox KL divergences, {unmask_every=}, {epoch=},\n",
    "t_quiescent=100, {quiescence=}, {lambda_scaling=}\"\"\")\n",
    "plt.tight_layout()\n",
    "\n",
    "stemplot({k:np.mean(v) for (k,v) in kls_250.items()}, xlabels)\n",
    "plt.title(\n",
    "f\"\"\"Unbiased RatInABox KL divergences, {unmask_every=}, {epoch=},\n",
    "t_quiescent=250, {quiescence=}, {lambda_scaling=}\"\"\")\n",
    "plt.tight_layout()\n",
    "\n",
    "\n",
    "stemplot(kls)\n",
    "plt.title(\n",
    "f\"\"\"Biased RatInABox KL divergences, {unmask_every=}, {epoch=},\n",
    "{t_quiescent=}, {quiescence=}, {lambda_scaling=}\"\"\")\n",
    "plt.tight_layout()\n",
    "\n",
    "\n",
    "fig, axs = plt.subplots(nrows=4, ncols=4, figsize=(12,15))\n",
    "axs = axs.flatten()\n",
    "for i,(k,v) in enumerate(kde_plots.items()):\n",
    "    axs[i].imshow(v, cmap='coolwarm')\n",
    "    axs[i].set(title=k[:k.index(' ')] + '\\n' + k[k.index(' ')+1:])\n",
    "fig.suptitle(f'{unmask_every=}, {epoch=}, {t_quiescent=}, {tau_a=}')\n",
    "fig.tight_layout()\n",
    "''';"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Velocities and reach times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-step speeds: norm of the displacement between consecutive timesteps.\n",
    "# Replays are truncated to the length of the test data so both cover the same steps.\n",
    "# shape = seeds, N, T, 2\n",
    "t_vel = test_pos.shape[2]\n",
    "replay_vels = {\n",
    "    k: v[:, :, :t_vel].diff(dim=2).square().sum(-1).sqrt().flatten()\n",
    "    for (k, v) in replay.items()\n",
    "}\n",
    "test_vels = test_pos.diff(dim=2).square().sum(-1).sqrt().flatten()\n",
    "\n",
    "# `xlabels` was not defined anywhere in this notebook, so this cell raised NameError on a\n",
    "# fresh kernel. Derive one label per violin from the replay keys (same order as\n",
    "# replay_vels): braces and quotes are dropped and each parameter goes on its own line.\n",
    "# NOTE(review): assumes violinplot uses xlabels as per-violin tick labels; confirm in plots.py\n",
    "xlabels = [\n",
    "    k.strip(\"{}\").replace(\"'\", \"\").replace(\", \", \"\\n\") for k in replay_vels\n",
    "]\n",
    "\n",
    "violinplot(replay_vels, xlabels)\n",
    "plt.title(\n",
    "    \"Biased rat velocity distributions (yellow = medians) (dashed black = median awake)\"\n",
    "    + f\"\\n{unmask_every=}, {epoch=}, {t_quiescent=}, {quiescence=}, {lambda_scaling=}\"\n",
    ")\n",
    "# reference line at the median test speed, spanning x = 1 .. number of violins\n",
    "plt.hlines(\n",
    "    test_vels.median(),\n",
    "    1,\n",
    "    len(xlabels),\n",
    "    linestyle=\"--\",\n",
    "    color=\"black\",\n",
    "    alpha=0.5,\n",
    "    zorder=-1,\n",
    ")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Radii of the two origin-centred circles used by the scatter plot and reach-time cells below\n",
    "inner_radius = 0.2\n",
    "outer_radius = 0.45"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positions at a single timestep: test data versus a few replay settings, with circles of\n",
    "# radius outer_radius and inner_radius drawn around the origin for reference.\n",
    "t = 99\n",
    "replay_list = list(replay.values())\n",
    "test_at_t = test_pos.flatten(0, 1)[:, t]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(4, 4))\n",
    "ax.scatter(test_at_t[:, 0], test_at_t[:, 1], label=\"test\")\n",
    "for idx in (0, 3, -3, -1):  # positions within the replay dict's insertion order\n",
    "    replay_at_t = replay_list[idx].flatten(0, 1)[:, t]\n",
    "    ax.scatter(replay_at_t[:, 0], replay_at_t[:, 1], alpha=0.4, label=idx)\n",
    "for radius in (outer_radius, inner_radius):\n",
    "    ax.add_patch(plt.Circle((0, 0), radius=radius, fill=False))\n",
    "ax.set(title=\"stationary distributions\")\n",
    "ax.legend(loc=\"upper right\")\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def first_reach_times(positions):\n",
    "    \"\"\"Index of the first timestep at which each trajectory lies strictly between the radii.\n",
    "\n",
    "    positions must be 4-D (seed, N, T, 2); the seed and sample axes are merged, so the\n",
    "    result holds one index per trajectory.\n",
    "    \"\"\"\n",
    "    assert positions.ndim == 4  # seed, N, T, 2\n",
    "    dist = positions.flatten(0, 1).square().sum(-1).sqrt()\n",
    "    in_ring = ((dist > inner_radius) & (dist < outer_radius)).int()\n",
    "    # argmax would report 0 for a trajectory that never satisfies the condition, so\n",
    "    # insist that every trajectory satisfies it at least once\n",
    "    assert 0 not in in_ring.sum(1)\n",
    "    return in_ring.argmax(dim=1)\n",
    "\n",
    "\n",
    "reach_times = {k: first_reach_times(v) for k, v in replay.items()}\n",
    "rt_meds = {k: rt.float().median().item() for k, rt in reach_times.items()}\n",
    "rt_avgs = {k: rt.float().mean().item() for k, rt in reach_times.items()}\n",
    "\n",
    "# same statistics for the test data, kept apart from the per-setting dicts\n",
    "awake_reach_times = first_reach_times(test_pos)\n",
    "awake_rt_med = awake_reach_times.float().median().item()\n",
    "awake_rt_avg = awake_reach_times.float().mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Percentage change of the replay reach times relative to the awake reach time, laid out\n",
    "# as a (lambda_v, b_a) grid. The shape is taken from the sweep lists rather than a\n",
    "# hard-coded (4, 4), so editing lambda_v_vals / b_a_vals keeps this cell working.\n",
    "# Relies on rt_meds / rt_avgs keeping the lambda_v-major order in which replay was filled.\n",
    "grid_shape = (len(lambda_v_vals), len(b_a_vals))\n",
    "delta_arrays = {\n",
    "    \"median\": (np.array(list(rt_meds.values())).reshape(grid_shape) / awake_rt_med - 1)\n",
    "    * 100,\n",
    "    \"average\": (np.array(list(rt_avgs.values())).reshape(grid_shape) / awake_rt_avg - 1)\n",
    "    * 100,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show both percentage-change grids and their overall range, then export them\n",
    "stats = (\"median\", \"average\")\n",
    "heatmap_paths = {\n",
    "    \"median\": \"../csv/rt/med_rt_heatmap_biased.txt\",\n",
    "    \"average\": \"../csv/rt/avg_rt_heatmap_biased.txt\",\n",
    "}\n",
    "\n",
    "for stat in stats:\n",
    "    print(delta_arrays[stat], \"\\n\")\n",
    "print(\"min: \", min(delta_arrays[stat].min() for stat in stats))\n",
    "print(\"max: \", max(delta_arrays[stat].max() for stat in stats))\n",
    "\n",
    "os.makedirs(\"../csv/rt\", exist_ok=True)\n",
    "for stat in stats:\n",
    "    array_to_heatmap(delta_arrays[stat], heatmap_paths[stat])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
