{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e2da081",
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Imports and hyperparameters\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "\n",
    "import datasets\n",
    "from helpers import (\n",
    "    gather_seeds,\n",
    "    simulate_models,\n",
    "    reach_stats,\n",
    "    dict_reach_stats,\n",
    "    dict_to_csv,\n",
    "    array_to_heatmap,\n",
    ")\n",
    "from plots import (\n",
    "    plot_cov_hull,\n",
    "    dataset_cmaps,\n",
    "    plot_reference_lines,\n",
    "    add_border_and_ticks,\n",
    ")\n",
    "from metrics import calc_all_metric\n",
    "\n",
    "os.makedirs(\"csv\", exist_ok=True)\n",
    "\n",
    "# Hyperparameters\n",
    "unmask_every = 3  # selects which trained models/seeds are loaded\n",
    "# NOTE(review): sigma_s is not used anywhere below -- simulate_models is called\n",
    "# with a hard-coded sigma_s=0.05. Confirm which value is intended.\n",
    "sigma_s = sigma_r = 1\n",
    "T_multiplier = 4\n",
    "radius = 0.1  # radius within which endpoints have been \"reached\"\n",
    "\n",
    "# Parameter grid. lambda_v is the outer loop and b_a the inner one, so\n",
    "# param_combos is in row-major (lambda_v, b_a) order; later cells rely on this\n",
    "# when reshaping per-combo results into a len(lambda_vs) x len(b_as) grid.\n",
    "lambda_vs = [1, 0.9, 0.8, 0.7]\n",
    "b_as = [0, 0.5, 1]\n",
    "param_combos = [\n",
    "    {\"sigma_r\": sigma_r, \"lambda_v\": lambda_v, \"b_a\": b_a}\n",
    "    for lambda_v in lambda_vs\n",
    "    for b_a in b_as\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee541744",
   "metadata": {},
   "outputs": [],
   "source": [
    "endpoints = datasets.triangle_vertices_and_mus()[-1]\n",
    "groups = 6  # size of the group axis of the simulated replays\n",
    "\n",
    "# Gather seeds\n",
    "seeds = gather_seeds(\"triangle_dataset\", f\"unmask_every_{unmask_every}\")\n",
    "seeds = [s for s in seeds if s != 4]  # seed 4 is bad\n",
    "# Deduplicate (there are noleak seeds, too) and sort: the seed order is passed to\n",
    "# simulate_models and indexes the seed axis of the replays, so it must not depend\n",
    "# on set iteration order.\n",
    "seeds = sorted(set(seeds))\n",
    "print(seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1078ff2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulate models\n",
    "# replays is a dictionary of Tensors (keyed by str(param_combo)), awake_trajectories is one Tensor\n",
    "replays, awake_trajectories = simulate_models(\n",
    "    sigma_s=0.05,  # NOTE(review): differs from the sigma_s hyperparameter (1); confirm\n",
    "    hidden_dim=40,\n",
    "    T_multiplier=T_multiplier,\n",
    "    param_combos=param_combos,\n",
    "    seeds=seeds,\n",
    "    model_dir=\"results\",\n",
    "    # derived from unmask_every so the loaded models match the seeds gathered above\n",
    "    model_prefix=f\"triangle_dataset__unmask_every_{unmask_every}\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e03d71b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One panel per parameter combo showing every sampled replay, colored by group.\n",
    "# Expects the un-flattened replays, so run it before the cell that merges the seed\n",
    "# and sample axes.\n",
    "fig, axs = plt.subplots(nrows=4, ncols=3, figsize=(10, 5), sharex=True, sharey=True)\n",
    "axs = axs.flatten()\n",
    "for i, p in enumerate(tqdm(param_combos)):\n",
    "    # x_hat has shape (T*T_multiplier, groups, len(seeds), N//groups, 2)\n",
    "    x_hat = replays[str(p)]\n",
    "    for j in range(len(seeds)):\n",
    "        for g in range(groups):\n",
    "            traj = x_hat[:, g, j]  # (time, samples, xy)\n",
    "            axs[i].plot(traj[..., 0], traj[..., 1], c=f\"C{g}\", alpha=0.05)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53bd7ffe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same figure as above, restricted to the first n_steps time steps.\n",
    "n_steps = 100\n",
    "fig, axs = plt.subplots(nrows=4, ncols=3, figsize=(10, 5), sharex=True, sharey=True)\n",
    "axs = axs.flatten()\n",
    "for i, p in enumerate(tqdm(param_combos)):\n",
    "    # x_hat has shape (n_steps, groups, len(seeds), N//groups, 2)\n",
    "    x_hat = replays[str(p)][:n_steps]\n",
    "    for j in range(len(seeds)):\n",
    "        for g in range(groups):\n",
    "            axs[i].plot(\n",
    "                x_hat[:, g, j, :, 0], x_hat[:, g, j, :, 1], c=f\"C{g}\", alpha=0.05\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24a97516",
   "metadata": {},
   "outputs": [],
   "source": [
    "# flatten replays: combine the len(seeds) and num samples dimensions together.\n",
    "# Only 5-D tensors are flattened, so re-running this cell is a no-op instead of\n",
    "# merging the sample axis into the xy axis.\n",
    "replays = {\n",
    "    k: v.flatten(start_dim=2, end_dim=3) if v.ndim == 5 else v\n",
    "    for (k, v) in replays.items()\n",
    "}\n",
    "# calculate reach time statistics\n",
    "reach_times, rt_meds, rt_avgs, rt_stds, failure_rates = dict_reach_stats(\n",
    "    replays, endpoints, radius\n",
    ")\n",
    "awake_rts, awake_rt_med, awake_rt_avg, awake_rt_std, _ = reach_stats(\n",
    "    awake_trajectories, endpoints, radius\n",
    ")\n",
    "# Percent change of each combo's median/average reach time relative to awake,\n",
    "# reshaped into a (lambda_v rows, b_a columns) grid. param_combos varies b_a\n",
    "# fastest; the per-combo values are assumed to come back in param_combos order.\n",
    "grid_shape = (\n",
    "    len({p[\"lambda_v\"] for p in param_combos}),\n",
    "    len({p[\"b_a\"] for p in param_combos}),\n",
    ")\n",
    "delta_arrays = {\n",
    "    \"median\": (np.array(list(rt_meds.values())).reshape(grid_shape) / awake_rt_med - 1)\n",
    "    * 100,\n",
    "    \"average\": (np.array(list(rt_avgs.values())).reshape(grid_shape) / awake_rt_avg - 1)\n",
    "    * 100,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbe73c4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "cmap = plt.cm.seismic  # diverging colormap: white marks 0 % change\n",
    "# tick labels in the order the heatmap rows/columns were built from param_combos\n",
    "lambda_labels = list(dict.fromkeys(p[\"lambda_v\"] for p in param_combos))\n",
    "b_a_labels = list(dict.fromkeys(p[\"b_a\"] for p in param_combos))\n",
    "fig, axs = plt.subplots(nrows=1, ncols=len(delta_arrays))\n",
    "# Use limits symmetric around 0 so the colormap's white midpoint means \"same as\n",
    "# awake\"; plain min/max limits shift the midpoint and make the colors misleading.\n",
    "vlim = max(np.nanmax(np.abs(a)) for a in delta_arrays.values())\n",
    "for i, (k, v) in enumerate(delta_arrays.items()):\n",
    "    im = axs[i].imshow(v, cmap=cmap, vmin=-vlim, vmax=vlim)\n",
    "    axs[i].set(\n",
    "        title=k,\n",
    "        xlabel=\"Adaptation strength $b_a$\",\n",
    "        ylabel=r\"Underdampening constant $\\lambda$\",\n",
    "        xticks=list(range(len(b_a_labels))),\n",
    "        xticklabels=b_a_labels,\n",
    "        yticks=list(range(len(lambda_labels))),\n",
    "        yticklabels=lambda_labels,\n",
    "    )\n",
    "fig.colorbar(im, ax=axs.ravel().tolist(), location=\"bottom\")\n",
    "fig.suptitle(f\"Triangle task reach time stats\\nover {len(seeds)} seeds\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e1fc5c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save results: one heatmap text file per reach-time statistic\n",
    "heatmap_paths = {\n",
    "    \"median\": \"csv/med_rt_heatmap_triangle.txt\",\n",
    "    \"average\": \"csv/avg_rt_heatmap_triangle.txt\",\n",
    "}\n",
    "for stat, path in heatmap_paths.items():\n",
    "    array_to_heatmap(delta_arrays[stat], path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3ee5639",
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = [\"blue\", \"purple\", \"red\"]  # one per b_a value, indexed by i % 3\n",
    "markers = [\"o\", \"s\", \"D\", \"^\"]  # one per lambda_v value, indexed by i // 3\n",
    "# The backslash before \"lambda\" is escaped: \"\\l\" is an invalid escape sequence\n",
    "# (SyntaxWarning on Python 3.12+). \"\\n\" is a real newline inside each label.\n",
    "xlabels = [\n",
    "    f\"$\\\\lambda_v = {pc['lambda_v']},$\\n$b_a = {pc['b_a']}$\" for pc in param_combos\n",
    "]\n",
    "\n",
    "\n",
    "def violinplot(array_dict):\n",
    "    \"\"\"Draw one violin per parameter combo on a new 12x4 figure.\n",
    "\n",
    "    The values of array_dict hold the samples for each combo. They are labelled\n",
    "    positionally with xlabels, so they are assumed to be in param_combos order.\n",
    "    Bodies are colored by b_a (colors) and medians are drawn in yellow.\n",
    "    Violins sit at x = 1..len(param_combos).\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(12, 4))\n",
    "    plt.xticks(np.arange(len(param_combos)) + 1, xlabels)\n",
    "    parts = plt.violinplot(list(array_dict.values()), showmedians=True)\n",
    "    for i, body in enumerate(parts[\"bodies\"]):\n",
    "        body.set(facecolor=colors[i % 3], alpha=1)\n",
    "    parts[\"cmedians\"].set_color(\"yellow\")\n",
    "\n",
    "\n",
    "def stemplot(scalar_dict):\n",
    "    \"\"\"Draw one stem (vertical line from 0 plus a marker) per parameter combo.\n",
    "\n",
    "    Uses a new 12x2 figure. The scalar values of scalar_dict are assumed to be\n",
    "    in param_combos order; color encodes b_a and marker shape encodes lambda_v.\n",
    "    Stems sit at x = 0..len(param_combos) - 1.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(12, 2))\n",
    "    plt.xticks(range(len(param_combos)), xlabels)\n",
    "    values = list(scalar_dict.values())\n",
    "    for i in range(len(param_combos)):\n",
    "        plt.plot([i] * 2, [0, values[i]], c=colors[i % 3])\n",
    "        plt.plot(i, values[i], marker=markers[i // 3], c=colors[i % 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c106a1ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "violinplot(reach_times)\n",
    "plt.title(\n",
    "    \"Triangle reach time distributions (medians are in yellow) (median awake is dashed black)\"\n",
    ")\n",
    "# axhline spans the whole axis without stretching the x-limits past the last\n",
    "# violin (violins sit at x = 1..len(param_combos)).\n",
    "plt.axhline(\n",
    "    float(np.median(awake_rt_med)),\n",
    "    linestyle=\"--\",\n",
    "    c=\"black\",\n",
    "    alpha=0.5,\n",
    "    zorder=-1,\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1d30e99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# failure rate per parameter combo\n",
    "stemplot(failure_rates)\n",
    "plt.gca().set_title(\"% of trials that failed to reach the end goals\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c098edc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-step speed: Euclidean norm of successive position differences. Replays are\n",
    "# truncated to the awake trajectory length so both cover the same horizon.\n",
    "replay_vels = {\n",
    "    k: v[: len(awake_trajectories)].diff(dim=0).square().sum(-1).sqrt().flatten()\n",
    "    for (k, v) in replays.items()\n",
    "}\n",
    "awake_vels = awake_trajectories.diff(dim=0).square().sum(-1).sqrt().flatten()\n",
    "\n",
    "violinplot(replay_vels)\n",
    "plt.title(\n",
    "    \"Triangle velocity distributions (T_mult = 1) (medians are in yellow) (median awake is dashed black)\"\n",
    ")\n",
    "# Full-width reference line behind the violins (same style as the reach-time plot);\n",
    "# unlike plotting up to x = 1 + len(param_combos), it does not widen the x-limits.\n",
    "plt.axhline(\n",
    "    float(awake_vels.median()),\n",
    "    linestyle=\"--\",\n",
    "    c=\"black\",\n",
    "    alpha=0.5,\n",
    "    zorder=-1,\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daccea70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wasserstein distance (\"wd\") of each combo's replays against awake_trajectories.\n",
    "# The covariance matrices are singular, so KL divergence is not used here.\n",
    "wds = calc_all_metric(replays, awake_trajectories, \"wd\")\n",
    "stemplot(wds)\n",
    "plt.gca().set_title(\"Triangle Wasserstein distances\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "485205fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save data; the directory is (re)created so this cell also works on its own\n",
    "wd_csv_path = \"csv/wd_triangle.csv\"\n",
    "os.makedirs(\"csv\", exist_ok=True)\n",
    "dict_to_csv(wds, col_param=\"b_a\", path=wd_csv_path);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d81fc06",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(nrows=4, ncols=3, figsize=(10, 5), sharex=True, sharey=True)\n",
    "axs = axs.flatten()\n",
    "# Endpoint slices joined by faint dotted reference lines.\n",
    "# NOTE(review): endpoints[3] is not covered by any slice (:2, 1:3, 4:) -- confirm intended.\n",
    "ref_slices = [slice(None, 2), slice(1, 3), slice(4, None)]\n",
    "for i, p in enumerate(tqdm(param_combos)):\n",
    "    x_hat = replays[str(p)]\n",
    "    mean_hat = x_hat.mean(2)  # mean over samples (axis 2)\n",
    "    # mean trajectory of each group over the first 100 time steps\n",
    "    for g in range(groups):\n",
    "        axs[i].plot(mean_hat[:100, g, 0], mean_hat[:100, g, 1], c=f\"C{g}\")\n",
    "    # Reference lines and title belong to the panel, not to a group: draw them once\n",
    "    # (one alpha=0.2 copy per group stacked up to a much darker line).\n",
    "    for ref in ref_slices:\n",
    "        axs[i].plot(\n",
    "            endpoints[ref, 0],\n",
    "            endpoints[ref, 1],\n",
    "            linestyle=\":\",\n",
    "            color=\"black\",\n",
    "            zorder=-1,\n",
    "            alpha=0.2,\n",
    "        )\n",
    "    axs[i].set(title=rf\"$\\lambda_v = {p['lambda_v']}, b_a = {p['b_a']}$\")\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e737b1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save results: output directories for the mean-trajectory figures\n",
    "for directory in (\"figures\", \"figures/means\"):\n",
    "    os.makedirs(directory, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "399d3c2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "T_awake = len(awake_trajectories)\n",
    "cmaps = dataset_cmaps(\"triangle\")\n",
    "xlim = (-0.1, 1.4)\n",
    "ylim = (-0.3, 1.15)\n",
    "\n",
    "# One saved SVG per corner of the (lambda_v, b_a) grid: for each group, the\n",
    "# plot_cov_hull of its replays plus its mean trajectory colored by time.\n",
    "for lambda_v in [1, 0.7]:\n",
    "    for b_a in [0, 1]:\n",
    "        fig, ax = plt.subplots(dpi=100)\n",
    "        ax.axis(\"off\")\n",
    "        ax.set(xlim=xlim, ylim=ylim)\n",
    "\n",
    "        x_hat = replays[str({\"sigma_r\": sigma_r, \"lambda_v\": lambda_v, \"b_a\": b_a})]\n",
    "        for group in range(groups):\n",
    "            group_hat = x_hat[:T_awake, group]  # truncated to the awake length\n",
    "            plot_cov_hull(group_hat, ax=ax, color=cmaps[group](0), alpha=0.3, zorder=-1)\n",
    "            mean = group_hat.mean(1)  # mean trajectory\n",
    "            # one segment per time step so the color can advance along the colormap\n",
    "            for t in range(T_awake - 1):\n",
    "                ax.plot(\n",
    "                    mean[t : t + 2, 0],\n",
    "                    mean[t : t + 2, 1],\n",
    "                    c=cmaps[group](t / T_awake),\n",
    "                    linewidth=7,\n",
    "                    zorder=1,\n",
    "                )\n",
    "\n",
    "        # The reference lines do not depend on the group, so draw them once per\n",
    "        # figure, after the hulls (drawing them inside the group loop stacked one\n",
    "        # alpha=0.15 copy per group, making them far darker than intended).\n",
    "        plot_reference_lines(\n",
    "            ax=ax,\n",
    "            dataset=\"triangle\",\n",
    "            linestyle=\"dashed\",\n",
    "            color=\"black\",\n",
    "            alpha=0.15,\n",
    "            zorder=-1,\n",
    "            linewidth=4,\n",
    "        )\n",
    "        add_border_and_ticks(ax, xlim, ylim)\n",
    "\n",
    "        fig.savefig(\n",
    "            f\"figures/means/triangle_{lambda_v=},{b_a=}.svg\",\n",
    "            bbox_inches=\"tight\",\n",
    "            pad_inches=0,\n",
    "        )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
