{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "312e3aa1",
   "metadata": {},
   "source": [
    "# Exploration in Unbiased Spatial Navigation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27463d9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import pathlib\n",
    "import pickle\n",
    "import sys\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "\n",
    "from plots import plot_pos, scatter_pos, stemplot, violinplot, add_border_and_ticks\n",
    "from helpers import (\n",
    "    optimal_decoder,\n",
    "    gather_models,\n",
    "    gather_test_data,\n",
    "    momentum,\n",
    "    dict_to_csv,\n",
    "    wd_types_to_csvs,\n",
    "    array_to_heatmap,\n",
    ")\n",
    "from metrics import wasserstein, calc_wds  # , calc_kls\n",
    "\n",
    "os.makedirs(\"pickle\", exist_ok=True)\n",
    "replay_dict_path = \"pickle/replay_unbiased.pkl\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2923f104",
   "metadata": {},
   "source": [
    "## Load models and data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "197d8cfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "unmask_every = 6\n",
    "epoch = 15000\n",
    "# epoch = 5000 #2500\n",
    "t_quiescent = 500  # 1000 #1000 #250 #100\n",
    "quiescence = \"same\"\n",
    "lambda_scaling = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea6699f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = pathlib.Path(\"../results\")\n",
    "noise = 0.0001\n",
    "config_name = f\"noisy_unbiased_unmask_every_{unmask_every}_{noise}\"\n",
    "seed_dirs = [\n",
    "    results / \"spatial_navigation\" / f\n",
    "    for f in os.listdir(results / \"spatial_navigation\")\n",
    "    if config_name in f\n",
    "]\n",
    "seed_dirs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ee6e678",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load models and their test data\n",
    "models = gather_models(seed_dirs, epoch)\n",
    "test_obs, test_pos, test_inits = gather_test_data(models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d1c533",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load replays\n",
    "with open(replay_dict_path, \"rb\") as file:\n",
    "    replay = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71f8e4a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "T_awake = test_pos.shape[-2]\n",
    "\n",
    "\n",
    "def calculate_path_lens(x):\n",
    "    assert x.ndim == 3 and x.shape[-1] == 2\n",
    "    # differentiate along T (dim 1),\n",
    "    # calculate dx (square sum sqrt),\n",
    "    # sum along T (now dim -1)\n",
    "    return x.diff(dim=1).square().sum(-1).sqrt().sum(-1)\n",
    "\n",
    "\n",
    "def calculate_mean_displacements(x):\n",
    "    assert x.ndim == 3 and x.shape[-1] == 2\n",
    "    # subtract x wrt x(t=0) (time is dim 1),\n",
    "    # calculate distance (square, sum, sqrt),\n",
    "    # take mean along samples (dim 0)\n",
    "    return (x - x[:, 0:1]).square().sum(-1).sqrt().mean(0)\n",
    "\n",
    "\n",
    "def calculate_variances(x):\n",
    "    assert x.ndim == 3 and x.shape[-1] == 2\n",
    "    # calculate variance of x along samples, take mean along dimensions.\n",
    "    # the output should be a vector of length = # timestpes (dim 1)\n",
    "    return x.var(0).mean(-1)\n",
    "\n",
    "\n",
    "dist_avg, dist_std, post_avg, post_std, md, variances = {}, {}, {}, {}, {}, {}\n",
    "for k, x_hat in replay.items():\n",
    "    # combine seed and sample dim. x_hat shape = (-1, timesteps, 2)\n",
    "    x_hat = x_hat.flatten(0, 1)\n",
    "    path_lens = calculate_path_lens(x_hat)\n",
    "    post_lens = calculate_path_lens(x_hat[:, T_awake:])  # everything after T_awake\n",
    "\n",
    "    # populate the dicts\n",
    "    dist_avg[k] = path_lens.mean().item()\n",
    "    dist_std[k] = path_lens.sum(-1).std().item()\n",
    "    post_avg[k] = post_lens.mean().item()\n",
    "    post_std[k] = post_lens.sum(-1).std().item()\n",
    "    md[k] = calculate_mean_displacements(x_hat)\n",
    "    variances[k] = calculate_variances(x_hat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ca6829b",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"../csv/exploration\", exist_ok=True)\n",
    "print(calculate_path_lens(test_pos.flatten(0, 1)).mean().item())\n",
    "dict_to_csv(\n",
    "    dist_avg,\n",
    "    col_param=\"b_a\",\n",
    "    path=\"../csv/exploration/exploration_distance_unbiased.csv\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7830f63e",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Dispalcement of test_pos after {T_awake=} is 0 by definition\")\n",
    "dict_to_csv(\n",
    "    post_avg, col_param=\"b_a\"\n",
    ")  # , path='../csv/exploration/exploration_post_unbiased.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1181f4a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"../figures/mean_displacements\", exist_ok=True)\n",
    "os.makedirs(\"../figures/variance\", exist_ok=True)\n",
    "\n",
    "from cycler import cycler\n",
    "\n",
    "cmap = plt.get_cmap(\"inferno\")\n",
    "plt.rc(\"axes\", prop_cycle=cycler(color=cmap(np.linspace(0, 1, 4))))  # 4 lines\n",
    "\n",
    "plt.rcParams[\"font.family\"] = \"serif\"\n",
    "plt.rcParams[\"font.size\"] = 12  # previously 10\n",
    "\n",
    "fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(12, 4), sharex=True, sharey=True)\n",
    "for k, v in md.items():\n",
    "    k = eval(k)\n",
    "    b, l = k[\"b_a\"], k[\"lambda_v\"]\n",
    "    ax = axs[[0, 0.1, 0.2, 0.3].index(b)]\n",
    "    ax.plot(v, label=f'$\\lambda_v = {k[\"lambda_v\"]}$')\n",
    "    ax.set(title=f\"$b_a={b}$\")\n",
    "for ax in axs:\n",
    "    ax.plot(\n",
    "        calculate_mean_displacements(test_pos.flatten(0, 1)),\n",
    "        c=\"blue\",\n",
    "        marker=\"x\",\n",
    "        markevery=10,\n",
    "        zorder=-1,\n",
    "        label=\"Awake\",\n",
    "    )\n",
    "    ax.set(facecolor=\"gray\", xlabel=\"Timesteps\", ylabel=\"Mean Displacement\"),\n",
    "    ax.legend(loc=\"lower right\", fontsize=10)\n",
    "fig.suptitle(\"Unbiased Rat Mean Displacements\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"../figures/mean_displacements/unbiased_md.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae7c568f",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(12, 4), sharex=True, sharey=True)\n",
    "for k, v in variances.items():\n",
    "    k = eval(k)\n",
    "    b, l = k[\"b_a\"], k[\"lambda_v\"]\n",
    "    ax = axs[[0, 0.1, 0.2, 0.3].index(b)]\n",
    "    ax.plot(v, label=f'$\\lambda_v = {k[\"lambda_v\"]}$')\n",
    "    ax.set(title=f\"$b_a={b}$\")\n",
    "for ax in axs:\n",
    "    ax.plot(\n",
    "        calculate_variances(test_pos.flatten(0, 1)),\n",
    "        c=\"blue\",\n",
    "        marker=\"x\",\n",
    "        markevery=10,\n",
    "        zorder=-1,\n",
    "        label=\"Awake\",\n",
    "    )\n",
    "    ax.set(facecolor=\"gray\", xlabel=\"Timesteps\", ylabel=\"Instantaneous Variance\"),\n",
    "    ax.legend(fontsize=10)\n",
    "fig.suptitle(\"Unbiased Rat Variances\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"../figures/variance/unbiased_var.svg\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
