{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3898e40e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Imports and hyperparameters\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "import datasets\n",
    "from helpers import gather_seeds, simulate_models, reach_stats, dict_reach_stats\n",
    "from plots import plot_cov_hull, dataset_cmaps, plot_reference_lines\n",
    "from metrics import calc_all_metric\n",
    "\n",
    "# Hyperparameters\n",
    "unmask_every_vals = (1, 3)\n",
    "sigma_r_vals = (1, 1.4)\n",
    "T_multiplier = 1\n",
    "\n",
    "param_combos = [{\"sigma_r\": sigma_r_vals[0]}, {\"sigma_r\": sigma_r_vals[1]}]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a539512d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather seeds\n",
    "seed_sets = [\n",
    "    set(gather_seeds(\"triangle_dataset\", f\"unmask_every_{u_e_val}\"))\n",
    "    for u_e_val in unmask_every_vals\n",
    "]\n",
    "seeds = list(set.intersection(*seed_sets))\n",
    "seeds = list(set(seeds))  # keep only unique values, since there are noleak seeds, too\n",
    "print(seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa3d656f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulate models\n",
    "# replays is a dictionary of Tensors, awake_trajectories is one Tensor\n",
    "replays_uev0, awake_trajectories = simulate_models(\n",
    "    sigma_s=1,\n",
    "    hidden_dim=40,\n",
    "    T_multiplier=T_multiplier,\n",
    "    param_combos=param_combos,\n",
    "    seeds=seeds,\n",
    "    model_dir=\"results\",\n",
    "    model_prefix=f\"triangle_dataset__unmask_every_{unmask_every_vals[0]}\",\n",
    ")\n",
    "\n",
    "replays_uev1, _ = simulate_models(\n",
    "    sigma_s=1,\n",
    "    hidden_dim=40,\n",
    "    T_multiplier=T_multiplier,\n",
    "    param_combos=param_combos,\n",
    "    seeds=seeds,\n",
    "    model_dir=\"results\",\n",
    "    model_prefix=f\"triangle_dataset__unmask_every_{unmask_every_vals[1]}\",\n",
    ")\n",
    "\n",
    "# combine the two dicts\n",
    "replays = {}\n",
    "key = lambda i, j: str(\n",
    "    {\"unmask_every\": unmask_every_vals[i], \"sigma_r\": sigma_r_vals[j]}\n",
    ")\n",
    "replays[key(0, 0)] = replays_uev0[str(param_combos[0])]\n",
    "replays[key(0, 1)] = replays_uev0[str(param_combos[1])]\n",
    "replays[key(1, 0)] = replays_uev1[str(param_combos[0])]\n",
    "replays[key(1, 1)] = replays_uev1[str(param_combos[1])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38a3b5a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 6), sharex=True, sharey=True)\n",
    "\n",
    "for i, unmask_every in enumerate(unmask_every_vals):\n",
    "    for j, unmask_every in enumerate(sigma_r_vals):\n",
    "        # shape = T, groups, seeds, N//groups, 2\n",
    "        replays[key(i, j)] = replays[key(i, j)].flatten(start_dim=2, end_dim=3)\n",
    "        x_hat = replays[key(i, j)]  # T, groups, seeds*N//groups, 2\n",
    "        for group in range(6):\n",
    "            group_replay = x_hat[:, group]  # shape = T, number of samples, 2\n",
    "            mean, std = group_replay.mean(1), group_replay.std(axis=1)\n",
    "            axs[i, j].plot(mean[:, 0], mean[:, 1], c=f\"C{group}\")\n",
    "            axs[i, j].fill_between(\n",
    "                mean[:, 0], (mean - std)[:, 1], (mean + std)[:, 1], alpha=0.3\n",
    "            )\n",
    "        axs[i, j].set(\n",
    "            title=f\"unmask_every = {unmask_every_vals[i]}, sigma_r = {sigma_r_vals[j]}\"\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5228d27a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(4, 2))\n",
    "for group in range(6):\n",
    "    group_replay = awake_trajectories[:, group]  # shape = T, number of samples, 2\n",
    "    mean, std = group_replay.mean(1), group_replay.std(axis=1)\n",
    "    plt.plot(mean[:, 0], mean[:, 1], c=f\"C{group}\")\n",
    "    plt.fill_between(mean[:, 0], (mean - std)[:, 1], (mean + std)[:, 1], alpha=0.3)\n",
    "plt.title(\"awake trajectories\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "553ca928",
   "metadata": {},
   "outputs": [],
   "source": [
    "def stemplot(scalar_dict):\n",
    "    plt.figure(figsize=(12, 2))\n",
    "    plt.xticks(range(4), list(replays.keys()))\n",
    "    ls = list(scalar_dict.values())\n",
    "    for i in range(4):\n",
    "        plt.plot([i] * 2, [0, ls[i]], c=\"black\")\n",
    "        plt.plot(i, ls[i], marker=\"o\", c=\"black\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e23a2077",
   "metadata": {},
   "outputs": [],
   "source": [
    "wds = calc_all_metric(replays, awake_trajectories, metric=\"wd\")\n",
    "stemplot(wds)\n",
    "plt.plot(\n",
    "    range(4), [wds[key(1, 0)]] * 4, linestyle=\"--\", c=\"black\", alpha=0.5, zorder=-1\n",
    ")\n",
    "plt.title(\n",
    "    f\"triangle experiment Wasserstein distances over {len(seeds)} seeds (default is gray dashed)\"\n",
    ");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "865c5137",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert T_multiplier == 1\n",
    "\n",
    "fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 6), sharex=True, sharey=True)\n",
    "\n",
    "cmaps = dataset_cmaps(\"triangle\")\n",
    "T = len(list(replays.values())[0])\n",
    "\n",
    "for i, unmask_every in enumerate(unmask_every_vals):\n",
    "    for j, unmask_every in enumerate(sigma_r_vals):\n",
    "        x_hat = replays[key(i, j)]  # T, groups, seeds*N//groups, 2\n",
    "        for group in range(6):\n",
    "            group_hat = x_hat[:, group]\n",
    "            plot_cov_hull(\n",
    "                group_hat, ax=axs[i, j], color=cmaps[group](0), alpha=0.3, zorder=-1\n",
    "            )\n",
    "            plot_reference_lines(\n",
    "                ax=axs[i, j],\n",
    "                dataset=\"triangle\",\n",
    "                linestyle=\"dashed\",\n",
    "                color=\"black\",\n",
    "                alpha=0.5,\n",
    "                zorder=-1,\n",
    "                linewidth=3,\n",
    "            )\n",
    "            mean = group_hat.mean(1)  # mean trajectory\n",
    "            for t in range(T - 1):\n",
    "                axs[i, j].plot(\n",
    "                    mean[t : t + 2, 0],\n",
    "                    mean[t : t + 2, 1],\n",
    "                    c=cmaps[group](t / T),\n",
    "                    linewidth=4,\n",
    "                    zorder=1,\n",
    "                )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
