{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e2da081",
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Imports and hyperparameters\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from helpers import gather_seeds\n",
    "\n",
    "# Hyperparameters\n",
    "unmask_every = 3\n",
    "# dataset = 'triangle'\n",
    "dataset = \"tmaze\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee541744",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather seeds\n",
    "bad_seeds = {\"triangle\": {4}, \"tmaze\": {3}}  # runs excluded as bad, per dataset\n",
    "found = gather_seeds(f\"{dataset}_dataset\", f\"unmask_every_{unmask_every}__noleak\")\n",
    "# gather_seeds may list a seed more than once (presumably because several result\n",
    "# files match per seed -- verify against helpers), so deduplicate via a set and\n",
    "# drop the bad runs. Set iteration order is not guaranteed, hence the explicit\n",
    "# sort: it keeps everything derived from `seeds` in a reproducible order.\n",
    "seeds = sorted(set(found) - bad_seeds.get(dataset, set()))\n",
    "print(seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c9dad07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the `loss` array of every seed, for the default and the `__noleak` runs\n",
    "prefix = f\"{dataset}_dataset__unmask_every_{unmask_every}\"\n",
    "\n",
    "\n",
    "def load_losses(middle):\n",
    "    \"\"\"Stack the `loss` arrays of all seeds for one run variant.\n",
    "\n",
    "    `middle` is the file-name infix selecting the variant (\"\" or \"__noleak\").\n",
    "    Returns an array whose first axis follows the order of `seeds`.\n",
    "    \"\"\"\n",
    "    curves = []\n",
    "    for s in seeds:\n",
    "        fname = f\"{prefix}{middle}__seed_{s:02d}__extra.npz\"\n",
    "        # np.load on an .npz returns an NpzFile that keeps the archive open;\n",
    "        # the context manager closes it once the array has been read.\n",
    "        with np.load(\"results/\" + fname) as data:\n",
    "            curves.append(data[\"loss\"])\n",
    "    return np.stack(curves)\n",
    "\n",
    "\n",
    "losses = load_losses(\"\")\n",
    "noleak_losses = load_losses(\"__noleak\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bff110da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# downsample by a factor of 50 via averaging\n",
    "ds = 50\n",
    "\n",
    "# Take the logs once; both the summary table and the per-seed panel use them.\n",
    "log_losses = np.log(losses)\n",
    "noleak_log_losses = np.log(noleak_losses)\n",
    "\n",
    "\n",
    "def window_mean(curve):\n",
    "    \"\"\"Average a 1D curve over consecutive, non-overlapping windows of `ds` entries.\"\"\"\n",
    "    # Drop a trailing partial window so lengths that are not a multiple of `ds`\n",
    "    # do not make the reshape fail; divisible lengths are unaffected.\n",
    "    n_full = (len(curve) // ds) * ds\n",
    "    return curve[:n_full].reshape(-1, ds).mean(1)\n",
    "\n",
    "\n",
    "# Mean / std are taken across seeds (axis 0) first, then averaged per window.\n",
    "columns = [\"log_mean\", \"log_std\", \"noleak_log_mean\", \"noleak_log_std\"]\n",
    "rows = [\n",
    "    window_mean(log_losses.mean(0)),\n",
    "    window_mean(log_losses.std(0)),\n",
    "    window_mean(noleak_log_losses.mean(0)),\n",
    "    window_mean(noleak_log_losses.std(0)),\n",
    "]\n",
    "df = pd.DataFrame(rows, index=columns).T\n",
    "\n",
    "\n",
    "fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 3))\n",
    "\n",
    "# Left panel: downsampled mean with a +/- one std band, one colour per variant.\n",
    "# Colours are set explicitly so band and line always match.\n",
    "for mean, std, color, label in [\n",
    "    (df.log_mean, df.log_std, \"C0\", \"default\"),\n",
    "    (df.noleak_log_mean, df.noleak_log_std, \"C1\", \"noleak\"),\n",
    "]:\n",
    "    axs[0].fill_between(df.index, mean - std, mean + std, facecolor=color, alpha=0.4)\n",
    "    axs[0].plot(df.index, mean, c=color, alpha=0.8, label=label)\n",
    "axs[0].set(xlabel=f\"window index ({ds} samples per window)\", ylabel=\"log loss\")\n",
    "axs[0].legend()\n",
    "\n",
    "# Right panel: the full-resolution curve of every individual seed.\n",
    "axs[1].plot(log_losses.T, alpha=0.5, c=\"C0\")\n",
    "axs[1].plot(noleak_log_losses.T, alpha=0.5, c=\"C1\")\n",
    "axs[1].set(xlabel=\"sample index\", ylabel=\"log loss\")\n",
    "\n",
    "fig.suptitle(f\"{dataset} loss\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe1da2a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the downsampled statistics and show them.\n",
    "# to_csv does not create missing folders, so make sure `csv/` exists first;\n",
    "# otherwise a run on a fresh checkout fails at this cell.\n",
    "csv_path = Path(\"csv\") / f\"leak_losses_{dataset}.csv\"\n",
    "csv_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "df.to_csv(csv_path)\n",
    "df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
