{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "312e3aa1",
   "metadata": {},
   "source": [
    "# Exploration in Biased Spatial Navigation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27463d9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import pathlib\n",
    "import pickle\n",
    "import sys\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "\n",
    "from plots import plot_pos, scatter_pos, stemplot, violinplot, add_border_and_ticks\n",
    "from helpers import (\n",
    "    optimal_decoder,\n",
    "    gather_models,\n",
    "    gather_test_data,\n",
    "    momentum,\n",
    "    dict_to_csv,\n",
    "    wd_types_to_csvs,\n",
    "    array_to_heatmap,\n",
    ")\n",
    "from metrics import wasserstein, calc_wds\n",
    "\n",
    "os.makedirs(\"pickle\", exist_ok=True)\n",
    "replay_dict_path = \"pickle/replay_biased.pkl\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2923f104",
   "metadata": {},
   "source": [
    "## Load models and data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "197d8cfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "unmask_every = 3\n",
    "epoch = 15000\n",
    "# epoch = 5000 # 2500\n",
    "t_quiescent = 500  # 1000 # 1000 # 250 # 100\n",
    "quiescence = \"same\"\n",
    "lambda_scaling = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea6699f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = pathlib.Path(\"../results\")\n",
    "noise = 0.0001\n",
    "config_name = f\"noisy_biased_unmask_every_{unmask_every}_{noise}\"\n",
    "seed_dirs = [\n",
    "    results / \"spatial_navigation\" / f\n",
    "    for f in os.listdir(results / \"spatial_navigation\")\n",
    "    if config_name in f\n",
    "]\n",
    "seed_dirs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ee6e678",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load models and their test data\n",
    "models = gather_models(seed_dirs, epoch)\n",
    "test_obs, test_pos, test_inits = gather_test_data(models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d1c533",
   "metadata": {},
   "outputs": [],
   "source": [
    "tau_a = 100\n",
    "lambda_v_vals = [1, 0.8, 0.6, 0.5]\n",
    "b_a_vals = [0, 1, 3, 5]\n",
    "\n",
    "# load replays\n",
    "with open(replay_dict_path, \"rb\") as file:\n",
    "    replay = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71f8e4a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_angles(x):\n",
    "    assert x.ndim == 3 and x.shape[-1] == 2\n",
    "    angles = torch.atan(x[..., 1] / x[..., 0])\n",
    "    dangles = angles.diff(dim=1)  # diff along time dimension\n",
    "    # some logic to undo the angle-wrapping in torch.atan\n",
    "    dangles[dangles < -torch.pi / 2] += torch.pi\n",
    "    dangles[dangles > torch.pi / 2] -= torch.pi\n",
    "    angles = angles[:, 0:1] + torch.cat(\n",
    "        (torch.zeros(len(angles), 1), dangles.cumsum(dim=1)), dim=1\n",
    "    )\n",
    "    return angles, dangles\n",
    "\n",
    "\n",
    "def calculate_path_lens(x):\n",
    "    assert x.ndim == 3 and x.shape[-1] == 2\n",
    "    # differentiate along T (dim 1),\n",
    "    # calculate dx (square sum sqrt),\n",
    "    # sum along T (now dim -1)\n",
    "    return x.diff(dim=1).square().sum(-1).sqrt().sum(-1)\n",
    "\n",
    "\n",
    "def calculate_mean_displacements(x):\n",
    "    assert x.ndim == 3 and x.shape[-1] == 2\n",
    "    # subtract x wrt x(t=0) (time is dim 1),\n",
    "    # calculate distance (square, sum, sqrt),\n",
    "    # take mean along samples (dim 0)\n",
    "    return (x - x[:, 0:1]).square().sum(-1).sqrt().mean(0)\n",
    "\n",
    "\n",
    "def calculate_variances(x):\n",
    "    assert x.ndim == 3 and x.shape[-1] == 2\n",
    "    # calculate variance of x along samples, take mean along dimensions.\n",
    "    # the output should be a vector of length = # timestpes (dim 1)\n",
    "    return x.var(0).mean(-1)\n",
    "\n",
    "\n",
    "dist_avg, dist_std, angl_avg, angl_std, md, variances = {}, {}, {}, {}, {}, {}\n",
    "for k, x_hat in replay.items():\n",
    "    # combine seed and sample dim. x_hat shape = (-1, timesteps, 2)\n",
    "    x_hat = x_hat.flatten(0, 1)\n",
    "    _, dangles = calculate_angles(x_hat)\n",
    "    path_lens = calculate_path_lens(x_hat)\n",
    "\n",
    "    # populate the dicts\n",
    "    dist_avg[k] = path_lens.mean().item()\n",
    "    dist_std[k] = path_lens.sum(-1).std().item()\n",
    "    angl_avg[k] = dangles.abs().sum(-1).mean().item()\n",
    "    angl_std[k] = dangles.abs().sum(-1).std().item()\n",
    "    md[k] = calculate_mean_displacements(x_hat)\n",
    "    variances[k] = calculate_variances(x_hat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ca6829b",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"../csv/exploration\", exist_ok=True)\n",
    "print(calculate_path_lens(test_pos.flatten(0, 1)).mean().item())\n",
    "dict_to_csv(\n",
    "    dist_avg, col_param=\"b_a\", path=\"../csv/exploration/exploration_distance_biased.csv\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7830f63e",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(calculate_angles(test_pos.flatten(0, 1))[1].abs().sum(-1).mean().item())\n",
    "dict_to_csv(\n",
    "    angl_avg, col_param=\"b_a\", path=\"../csv/exploration/exploration_angular_biased.csv\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc39ca7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"../figures/mean_displacements\", exist_ok=True)\n",
    "os.makedirs(\"../figures/variance\", exist_ok=True)\n",
    "\n",
    "from cycler import cycler\n",
    "\n",
    "cmap = plt.get_cmap(\"inferno\")\n",
    "plt.rc(\"axes\", prop_cycle=cycler(color=cmap(np.linspace(0, 1, 4))))  # 4 lines\n",
    "\n",
    "plt.rcParams[\"font.family\"] = \"serif\"\n",
    "plt.rcParams[\"font.size\"] = 12  # previously 10\n",
    "\n",
    "fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(12, 4), sharex=True, sharey=True)\n",
    "for k, v in md.items():\n",
    "    k = eval(k)\n",
    "    b, l = k[\"b_a\"], k[\"lambda_v\"]\n",
    "    ax = axs[[0, 1, 3, 5].index(b)]\n",
    "    ax.plot(v, label=f'$\\lambda_v = {k[\"lambda_v\"]}$')\n",
    "    ax.set(title=f\"$b_a={b}$\")\n",
    "for ax in axs:\n",
    "    ax.plot(\n",
    "        calculate_mean_displacements(test_pos.flatten(0, 1)),\n",
    "        c=\"blue\",\n",
    "        marker=\"x\",\n",
    "        markevery=10,\n",
    "        zorder=-1,\n",
    "        label=\"Awake\",\n",
    "    )\n",
    "    ax.set(facecolor=\"gray\", xlabel=\"Timesteps\", ylabel=\"Mean Displacement\"),\n",
    "    ax.legend(loc=\"lower right\", fontsize=10)\n",
    "fig.suptitle(\"Biased Rat Mean Displacements\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"../figures/mean_displacements/biased_md.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24448fe7",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(12, 4), sharex=True, sharey=True)\n",
    "for k, v in variances.items():\n",
    "    k = eval(k)\n",
    "    b, l = k[\"b_a\"], k[\"lambda_v\"]\n",
    "    ax = axs[[0, 1, 3, 5].index(b)]\n",
    "    ax.plot(v, label=f'$\\lambda_v = {k[\"lambda_v\"]}$')\n",
    "    ax.set(title=f\"$b_a={b}$\")\n",
    "for ax in axs:\n",
    "    ax.plot(\n",
    "        calculate_variances(test_pos.flatten(0, 1)),\n",
    "        c=\"blue\",\n",
    "        marker=\"x\",\n",
    "        markevery=10,\n",
    "        zorder=-1,\n",
    "        label=\"Awake\",\n",
    "    )\n",
    "    ax.set(facecolor=\"gray\", xlabel=\"Timesteps\", ylabel=\"Instantaneous Variance\"),\n",
    "    ax.legend(fontsize=10)\n",
    "fig.suptitle(\"Biased Rat Variances\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"../figures/variance/biased_var.svg\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5a23d3d",
   "metadata": {},
   "source": [
    "### visualization of rotatory timesteps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8036e9c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "cmap = matplotlib.colormaps[\"inferno\"]  # 'plasma'\n",
    "T_awake = test_pos.shape[2]  # 100\n",
    "T_offset = 40\n",
    "xlim = (-0.5, 0.5)\n",
    "ylim = (-0.5, 0.5)\n",
    "\n",
    "for lv in (lambda_v_vals[0], lambda_v_vals[-1]):\n",
    "    for ba in (b_a_vals[0], b_a_vals[-1]):\n",
    "        fig, ax = plt.subplots(dpi=100)\n",
    "        ax.axis(\"off\")\n",
    "        ax.set(xlim=xlim, ylim=ylim)\n",
    "        add_border_and_ticks(\n",
    "            ax, xlim, ylim, [-0.9, 0, 0.9], [-0.9, 0, 0.9], extra_padding=(0.015, 0.02)\n",
    "        )  # gray\n",
    "\n",
    "        # x_hat = test_pos\n",
    "        x_hat = replay[str({\"lambda_v\": lv, \"tau_a\": tau_a, \"b_a\": ba})]\n",
    "        x_hat = x_hat.flatten(0, 1)  # combine the seed and sample dimensions\n",
    "        for t in range(T_offset, T_awake - 1):\n",
    "            # plot every 10th sample\n",
    "            ax.plot(\n",
    "                x_hat[::10, t : t + 2, 0].T,\n",
    "                x_hat[::10, t : t + 2, 1].T,\n",
    "                c=cmap(1 - (t - T_offset) / (T_awake - T_offset)),\n",
    "                alpha=0.4,\n",
    "                linewidth=2.5,\n",
    "            )  # , zorder=T-t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77ba08e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(dpi=100)\n",
    "ax.axis(\"off\")\n",
    "ax.set(xlim=xlim, ylim=ylim)\n",
    "add_border_and_ticks(\n",
    "    ax, xlim, ylim, [-0.9, 0, 0.9], [-0.9, 0, 0.9], extra_padding=(0.015, 0.02)\n",
    ")  # gray\n",
    "\n",
    "for t in range(T_offset, T_awake - 1):\n",
    "    # plot every 10th sample\n",
    "    ax.plot(\n",
    "        test_pos.flatten(0, 1)[::10, t : t + 2, 0].T,\n",
    "        test_pos.flatten(0, 1)[::10, t : t + 2, 1].T,\n",
    "        c=cmap(1 - (t - T_offset) / (T_awake - T_offset)),\n",
    "        alpha=0.4,\n",
    "        linewidth=2.5,\n",
    "    )  # , zorder=T-t)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
