{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b21e2b63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Imports and hyperparameters\n",
    "import ast\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy\n",
    "import torch\n",
    "from cycler import cycler\n",
    "from tqdm import tqdm\n",
    "\n",
    "import datasets\n",
    "from helpers import gather_seeds, simulate_models, dict_to_csv\n",
    "from helpers_exploration import (\n",
    "    plot_region_assignments_over_time,\n",
    "    calc_exploration_metrics,\n",
    "    region_counts_histogram,\n",
    "    region_counts_histogram_df,\n",
    ")\n",
    "from plots import (\n",
    "    dataset_cmaps,\n",
    "    add_border_and_ticks,\n",
    "    plot_reference_lines,\n",
    "    plot_tmaze_power_diagram,\n",
    ")\n",
    "\n",
    "os.makedirs(\"csv\", exist_ok=True)\n",
    "\n",
    "# Hyperparameters\n",
    "unmask_every = 3\n",
    "sigma_s = sigma_r = 0.05\n",
    "T_multiplier = 4\n",
    "radius = 0.1  # radius within which endpoints have been \"reached\"\n",
    "dataset = \"tmaze\"\n",
    "\n",
    "param_combos = []\n",
    "for lambda_v in [1, 0.9, 0.8, 0.7]:\n",
    "    for b_a in [0, 0.5, 1]:\n",
    "        param_combos.append({\"sigma_r\": sigma_r, \"lambda_v\": lambda_v, \"b_a\": b_a})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a539512d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of groups per model: axis 1 of the replay tensors simulated below.\n",
    "groups = 2\n",
    "\n",
    "# Gather the seeds of all trained models for this dataset / unmasking schedule.\n",
    "seeds = gather_seeds(f\"{dataset}_dataset\", f\"unmask_every_{unmask_every}\")\n",
    "# Drop seed 1 (our seed 1 is not good) and de-duplicate, since there are noleak\n",
    "# seeds, too. Sorting makes the order deterministic, so seed *indices* used\n",
    "# later (e.g. `chosen_groupseeds`) always refer to the same seed.\n",
    "seeds = sorted({s for s in seeds if s != 1})\n",
    "print(seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa3d656f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulate models.\n",
    "# replays is a dictionary of Tensors keyed by str(param_combo);\n",
    "# awake_trajectories is one Tensor.\n",
    "# model_prefix reuses `unmask_every` so it stays in sync with gather_seeds above.\n",
    "replays, awake_trajectories = simulate_models(\n",
    "    sigma_s=sigma_s,\n",
    "    hidden_dim=20,\n",
    "    T_multiplier=T_multiplier,\n",
    "    param_combos=param_combos,\n",
    "    seeds=seeds,\n",
    "    model_dir=\"results\",\n",
    "    model_prefix=f\"{dataset}_dataset__unmask_every_{unmask_every}\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33e7c434",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overview: every replay trajectory, one panel per parameter combination.\n",
    "# param_combos is lambda_v-major, so rows follow lambda_v and columns follow b_a.\n",
    "n_cols = len({p[\"b_a\"] for p in param_combos})\n",
    "n_rows = -(-len(param_combos) // n_cols)  # ceil division\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=n_rows, ncols=n_cols, figsize=(10, 7), sharex=True, sharey=True, squeeze=False\n",
    ")\n",
    "axs = axs.flatten()\n",
    "for p, ax in zip(tqdm(param_combos), axs):\n",
    "    x_hat = replays[str(p)]\n",
    "    # x_hat has shape (T*T_multiplier, groups, len(seeds), N//groups, 2)\n",
    "    for j in range(len(seeds)):\n",
    "        for g in range(groups):\n",
    "            ax.plot(\n",
    "                x_hat[:, g, j, :, 0], x_hat[:, g, j, :, 1], c=f\"C{g}\", alpha=0.03\n",
    "            )\n",
    "    ax.set_title(rf\"$\\lambda_v$={p['lambda_v']}, $b_a$={p['b_a']}\", fontsize=8)\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7c8f9f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-(seed, group) view of a single parameter combination (the last one by default).\n",
    "# squeeze=False keeps `axs` 2-D even if there is only one seed or one group.\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=len(seeds),\n",
    "    ncols=groups,\n",
    "    figsize=(4, 6),\n",
    "    sharex=True,\n",
    "    sharey=True,\n",
    "    squeeze=False,\n",
    ")\n",
    "# shape = T, groups, seeds, N_per_group, 2\n",
    "param_idx = -1\n",
    "examine_replay = replays[str(param_combos[param_idx])]\n",
    "fig.suptitle(str(param_combos[param_idx]))\n",
    "for s_i in range(len(seeds)):\n",
    "    for g in range(groups):\n",
    "        ax = axs[s_i, g]\n",
    "        groupseed_replay = examine_replay[:, g, s_i]  # (T, N_per_group, 2)\n",
    "        ax.plot(\n",
    "            groupseed_replay[..., 0],\n",
    "            groupseed_replay[..., 1],\n",
    "            c=f\"C{g}\",\n",
    "            alpha=0.2,\n",
    "            zorder=-1,\n",
    "        )\n",
    "        # Start (o) and end (x) points of each trajectory, drawn above the lines.\n",
    "        for t, marker in ((0, \"o\"), (-1, \"x\")):\n",
    "            ax.scatter(\n",
    "                groupseed_replay[t, :, 0],\n",
    "                groupseed_replay[t, :, 1],\n",
    "                c=\"k\",\n",
    "                marker=marker,\n",
    "                s=20,\n",
    "                alpha=0.3,\n",
    "                zorder=1,\n",
    "            )\n",
    "        ax.set(title=f\"Group {g},\\nseed index {s_i}\")\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feabd03d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Region assignments over time for `examine_replay` (chosen in the cell above).\n",
    "# The grid holds at least `groups` axes, two per row (presumably one panel per\n",
    "# group); the row count rounds up so an odd number of groups still fits.\n",
    "fig, axs = plt.subplots(nrows=(groups + 1) // 2, ncols=2, figsize=(15, 5))\n",
    "axs = axs.flatten()\n",
    "plot_region_assignments_over_time(dataset, axs, examine_replay, min_duration=20)\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b08484d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exploration metrics for every parameter combination, then CSVs of the\n",
    "# per-combination averages (one column per b_a value).\n",
    "(\n",
    "    every_region_count_value,\n",
    "    region_count_avgs,\n",
    "    every_distance,\n",
    "    distance_avgs,\n",
    ") = calc_exploration_metrics(dataset, replays, min_duration=20)\n",
    "\n",
    "rcnt_avg_path = f\"csv/exploration_{dataset}_region_counts_avg.csv\"\n",
    "dist_avg_path = f\"csv/exploration_{dataset}_distances_avg.csv\"\n",
    "df_rcnt_avg = dict_to_csv(region_count_avgs, col_param=\"b_a\", path=rcnt_avg_path)\n",
    "df_dist_avg = dict_to_csv(distance_avgs, col_param=\"b_a\", path=dist_avg_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea847918",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary statistics vs. lambda_v, one line per b_a value.\n",
    "fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 2))\n",
    "x = df_rcnt_avg[\"lambda_v\"]\n",
    "for b_a in [0.0, 0.5, 1.0]:\n",
    "    label = f\"{b_a=}\"  # column name in the dict_to_csv frames, e.g. \"b_a=0.5\"\n",
    "    axs[0].plot(x, df_rcnt_avg[label], label=label)\n",
    "    axs[1].plot(x, df_dist_avg[label], label=label)\n",
    "for ax in axs:\n",
    "    ax.legend()\n",
    "    ax.xaxis.set_inverted(True)  # lambda_v decreases from left to right\n",
    "    # Raw string so the backslash in \\lambda is not treated as an escape.\n",
    "    ax.set(xlabel=r\"$\\lambda_v$\")\n",
    "axs[0].set(ylabel=\"Mean regions entered\")\n",
    "axs[1].set(ylabel=\"Mean total distance\")\n",
    "fig.suptitle(\"Exploration summary statistics\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3dae334",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of `every_distance` for each parameter combination (all entries pooled).\n",
    "fig, ax = plt.subplots(figsize=(10, 3))\n",
    "for i, p in enumerate(param_combos):\n",
    "    ax.violinplot(every_distance[str(p)].flatten(), positions=[i])\n",
    "ax.set_xticks(list(range(len(param_combos))))\n",
    "ax.set_xticklabels(\n",
    "    [f\"{p['lambda_v']}, {p['b_a']}\" for p in param_combos], rotation=45\n",
    ")\n",
    "ax.set(xlabel=r\"$\\lambda_v$, $b_a$\", ylabel=\"Distance\", title=\"Distance distributions\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f8271f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of region counts for four representative combinations\n",
    "# (param_combos indices 0, 2, -3, -1: lambda_v in {1, 0.7} x b_a in {0, 1}).\n",
    "min_count, max_count = 2, 4\n",
    "width = 0.1\n",
    "x = np.arange(min_count, max_count + 1)\n",
    "shown_param_idxs = [0, 2, -3, -1]\n",
    "fig, ax = plt.subplots()\n",
    "for i, p_i in enumerate(shown_param_idxs):\n",
    "    p = param_combos[p_i]\n",
    "    # Offset each series so the group of bars is centred on its integer count.\n",
    "    offset = width * (i - (len(shown_param_idxs) - 1) / 2)\n",
    "    ax.bar(\n",
    "        x + offset,\n",
    "        region_counts_histogram(\n",
    "            every_region_count_value[str(p)].flatten(),\n",
    "            min_count=min_count,\n",
    "            max_count=max_count,\n",
    "        ),\n",
    "        width=width,\n",
    "        label=rf\"$\\lambda_v$={p['lambda_v']}, $b_a$={p['b_a']}\",\n",
    "    )\n",
    "ax.set_xticks(x)\n",
    "ax.set(xlabel=\"Regions entered\", title=\"Distribution of region counts\")\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ec39603",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save per-(group, seed) region-count histograms of the chosen combinations as CSV.\n",
    "os.makedirs(\"figures/exploration\", exist_ok=True)  # the figure cell below saves here\n",
    "# (lambda_v=1, b_a=1) vs. (lambda_v=0.7, b_a=1)\n",
    "chosen_params = (param_combos[2], param_combos[-1])\n",
    "# (group index, seed index into `seeds`) pairs\n",
    "chosen_groupseeds = [(1, 3)]\n",
    "dfs = []\n",
    "for g, s in chosen_groupseeds:\n",
    "    df_hist = region_counts_histogram_df(\n",
    "        every_region_count_value,\n",
    "        list(chosen_params),\n",
    "        \"lambda_v\",\n",
    "        group=g,\n",
    "        seed=s,\n",
    "        min_count=min_count,\n",
    "        max_count=max_count,\n",
    "    )\n",
    "    # NOTE: the file name encodes only the group, so two pairs sharing a group\n",
    "    # would overwrite each other's CSV.\n",
    "    df_hist.to_csv(\n",
    "        f\"csv/exploration_{dataset}_hist_group={g}.csv\", index_label=\"regions\"\n",
    "    )\n",
    "    dfs.append(df_hist)\n",
    "for cps in chosen_params:\n",
    "    print(cps)\n",
    "for df in dfs:\n",
    "    display(df.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ff14811",
   "metadata": {},
   "outputs": [],
   "source": [
    "xlim = (-1.1, 1.1)\n",
    "ylim = (-0.2, 1.2)\n",
    "\n",
    "\n",
    "def _draw_groupseed_replay(ax, groupseed_replay, color):\n",
    "    \"\"\"Draw one (group, seed) replay: trajectories, start circles and end crosses.\n",
    "\n",
    "    groupseed_replay is indexed [time, trajectory, xy] (shape T, N, 2).\n",
    "    \"\"\"\n",
    "    xs = groupseed_replay[..., 0]\n",
    "    ys = groupseed_replay[..., 1]\n",
    "    ax.plot(xs, ys, c=color, alpha=0.7, zorder=2)\n",
    "    # Start points: large yellow circles with a black outline.\n",
    "    ax.scatter(\n",
    "        xs[0],\n",
    "        ys[0],\n",
    "        facecolors=\"yellow\",\n",
    "        edgecolors=\"black\",\n",
    "        linewidths=1.5,\n",
    "        marker=\"o\",\n",
    "        alpha=1,\n",
    "        zorder=3,\n",
    "        s=400,\n",
    "    )\n",
    "    # End points: large black crosses on top of everything else.\n",
    "    ax.scatter(xs[-1], ys[-1], c=\"black\", marker=\"x\", alpha=1, zorder=4, s=400)\n",
    "\n",
    "\n",
    "# One SVG per (group, seed) pair and chosen parameter combination.\n",
    "for g, s in chosen_groupseeds:\n",
    "    color = dataset_cmaps(dataset)[g](0)\n",
    "    for params in chosen_params:\n",
    "        fig, ax = plt.subplots(dpi=100)\n",
    "        ax.axis(\"off\")\n",
    "        ax.set(xlim=xlim, ylim=ylim)\n",
    "        add_border_and_ticks(ax, xlim, ylim, [0, 1], [0, 1])\n",
    "        _draw_groupseed_replay(ax, replays[str(params)][:, g, s], color)\n",
    "        plot_reference_lines(\n",
    "            ax, dataset, colors=\"black\", linestyle=\"dashed\", zorder=1, alpha=1\n",
    "        )\n",
    "        plot_tmaze_power_diagram(\n",
    "            ax, ylim=ylim, colors=\"black\", linestyle=\":\", zorder=1, alpha=1\n",
    "        )\n",
    "        fig.savefig(\n",
    "            f\"figures/exploration/{dataset}_{g=},\"\n",
    "            + f'lambda_v={params[\"lambda_v\"]},b_a={params[\"b_a\"]}.svg',\n",
    "            bbox_inches=\"tight\",\n",
    "            pad_inches=0,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8460c22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean displacement from the starting point over time: one panel per b_a value,\n",
    "# one line per lambda_v, with the awake trajectories as a reference curve.\n",
    "os.makedirs(\"figures/mean_displacements\", exist_ok=True)\n",
    "os.makedirs(\"figures/variance\", exist_ok=True)\n",
    "\n",
    "b_a_values = sorted({p[\"b_a\"] for p in param_combos})  # panel order, e.g. [0, 0.5, 1]\n",
    "n_lambda_v = len({p[\"lambda_v\"] for p in param_combos})  # lines per panel\n",
    "\n",
    "# One inferno colour per lambda_v line (global rc setting, so later figures use it too).\n",
    "cmap = plt.get_cmap(\"inferno\")\n",
    "plt.rc(\"axes\", prop_cycle=cycler(color=cmap(np.linspace(0, 1, n_lambda_v))))\n",
    "\n",
    "plt.rcParams[\"font.family\"] = \"serif\"\n",
    "plt.rcParams[\"font.size\"] = 12  # previously 10\n",
    "\n",
    "# Awake reference curve, computed once and reused in every panel.\n",
    "# flatten(1, -2) merges every dim between time (first) and xy (last).\n",
    "awake = awake_trajectories.flatten(1, -2)\n",
    "awake_md = (awake - awake[:1]).square().sum(-1).sqrt().mean(1)\n",
    "\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=1,\n",
    "    ncols=len(b_a_values),\n",
    "    sharex=True,\n",
    "    sharey=True,\n",
    "    figsize=(12, 4),\n",
    "    squeeze=False,\n",
    ")\n",
    "axs = axs[0]\n",
    "for key, v in replays.items():\n",
    "    # Keys are str(param_combo); literal_eval parses them without executing code.\n",
    "    k = ast.literal_eval(key)\n",
    "    rs = v.flatten(1, -2)  # first dim = time, last dim = xy (2)\n",
    "    ax = axs[b_a_values.index(k[\"b_a\"])]\n",
    "    ax.set(title=f'$b_a = {k[\"b_a\"]}$')\n",
    "    ax.plot(\n",
    "        (rs - rs[:1]).square().sum(-1).sqrt().mean(1),\n",
    "        label=rf\"$\\lambda_v={k['lambda_v']}$\",\n",
    "    )\n",
    "\n",
    "for ax in axs:\n",
    "    ax.plot(\n",
    "        awake_md,\n",
    "        c=\"blue\",\n",
    "        marker=\"x\",\n",
    "        markevery=10,\n",
    "        zorder=-1,\n",
    "        label=\"Awake\",\n",
    "    )\n",
    "    ax.legend(loc=\"lower right\", fancybox=True, fontsize=10, ncol=2)\n",
    "    ax.set(facecolor=\"gray\", xlabel=\"Timestep\", ylabel=\"Mean Displacement\")\n",
    "fig.suptitle(\"T-Maze Mean Displacements\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"figures/mean_displacements/tmaze_md.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "474df511",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantaneous variance across all trajectories over time (averaged over x and y):\n",
    "# one panel per b_a value, one line per lambda_v, plus the awake reference curve.\n",
    "os.makedirs(\"figures/variance\", exist_ok=True)\n",
    "b_a_values = sorted({p[\"b_a\"] for p in param_combos})  # panel order, e.g. [0, 0.5, 1]\n",
    "\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=1,\n",
    "    ncols=len(b_a_values),\n",
    "    sharex=True,\n",
    "    sharey=True,\n",
    "    figsize=(12, 4),\n",
    "    squeeze=False,\n",
    ")\n",
    "axs = axs[0]\n",
    "for key, v in replays.items():\n",
    "    # Keys are str(param_combo); literal_eval parses them without executing code.\n",
    "    k = ast.literal_eval(key)\n",
    "    rs = v.flatten(1, -2)  # first dim = time, last dim = xy (2)\n",
    "    ax = axs[b_a_values.index(k[\"b_a\"])]\n",
    "    ax.set(title=f'$b_a = {k[\"b_a\"]}$')\n",
    "    ax.plot(rs.var(1).mean(-1), label=rf\"$\\lambda_v={k['lambda_v']}$\")\n",
    "\n",
    "# Awake reference curve, computed once and reused in every panel.\n",
    "awake_var = awake_trajectories.flatten(1, -2).var(1).mean(-1)\n",
    "for ax in axs:\n",
    "    ax.plot(\n",
    "        awake_var,\n",
    "        c=\"blue\",\n",
    "        marker=\"x\",\n",
    "        markevery=10,\n",
    "        zorder=-1,\n",
    "        label=\"Awake\",\n",
    "    )\n",
    "    ax.legend(loc=\"lower right\", fancybox=True, fontsize=10, ncol=2)\n",
    "    ax.set(facecolor=\"gray\", xlabel=\"Timestep\", ylabel=\"Instantaneous Variance\")\n",
    "fig.suptitle(\"T-Maze Variances\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"figures/variance/tmaze_var.svg\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
