{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import os\n",
    "from utils.vis_tool import walk_through\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning Curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global figure style for the per-length learning-curve plots.\n",
    "sns.set_style(\"whitegrid\", {\"grid.linestyle\": \"--\"})\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"figure.dpi\": 300,\n",
    "        \"figure.figsize\": (4, 3),\n",
    "        \"axes.labelsize\": 15,\n",
    "        \"axes.titlesize\": 12,\n",
    "        \"xtick.labelsize\": 14,\n",
    "        \"ytick.labelsize\": 14,\n",
    "        \"legend.fontsize\": 13,  # alternative: 10\n",
    "        \"axes.grid\": True,\n",
    "        \"legend.loc\": \"best\",\n",
    "        \"lines.linewidth\": 1.5,\n",
    "        \"axes.formatter.useoffset\": False,\n",
    "        \"axes.formatter.offset_threshold\": 1,\n",
    "        # \"font.size\": 8,\n",
    "        \"font.family\": \"serif\",\n",
    "        \"font.serif\": [\"Liberation Serif\"],\n",
    "        \"text.usetex\": True,  # needs a LaTeX installation on the machine\n",
    "    }\n",
    ")\n",
    "\n",
    "# Plot switches for the loop below:\n",
    "#   separate_legend: hide the in-axes legend (no \"_leg\" suffix in the file name)\n",
    "#   ablation: empty string keeps every sequence model; alternatives \"gpt\", \"lstm\"\n",
    "#   hue / style: columns for colour / line style (alternatives: \"config_rl.algo\", \"n_layer\")\n",
    "separate_legend = True\n",
    "ablation = \"\"\n",
    "hue = \"seq\"\n",
    "style = None\n",
    "\n",
    "\n",
    "def query_fn(flags):\n",
    "    \"\"\"Run filter passed to `walk_through`.\n",
    "\n",
    "    Reads the global `ablation`: an empty string accepts every run; otherwise only\n",
    "    runs whose sequence model name equals `ablation` pass (returns False to reject,\n",
    "    presumably True keeps the run).\n",
    "    \"\"\"\n",
    "    if ablation == \"\":\n",
    "        return True  # no ablation requested\n",
    "    seq_name = flags[\"config_seq\"][\"model\"][\"seq_model_config\"][\"name\"]\n",
    "    if seq_name != ablation:\n",
    "        return False\n",
    "    return True\n",
    "\n",
    "\n",
    "# One learning-curve figure per environment length; `end` is the x-axis limit (env steps) for that length.\n",
    "for env_len, end in zip(\n",
    "    [50, 100, 250, 500, 750, 1000, 1250, 1500], [1e6, 2e6, 4e6, 4e6, 6e6, 8e6, 1e7, 1e7]\n",
    "):\n",
    "    env_name = \"Passive T-Maze\"\n",
    "    path = f\"logs/tmaze_passive/{env_len}\"\n",
    "    title_tag = \"memory length\"\n",
    "    metric = \"return\"\n",
    "\n",
    "    # Other environments: replace the loop header and settings above with one of these.\n",
    "    # for env_len, end in zip([20, 50, 100, 250, 500], [0.8e6, 2e6, 4e6, 7e6, 7e6]):\n",
    "    #     env_name = \"Active T-Maze\"\n",
    "    #     path = f\"logs/tmaze_active/{env_len}\"\n",
    "    #     title_tag = \"credit assignment length\"\n",
    "    #     metric = \"return\"\n",
    "\n",
    "    # for env_len, end in zip([60, 120, 250, 500, 750, 1000],\n",
    "    #             [3.6e6, 4e6, 4e6, 5.3e6, 7.8e6, 10e6]):\n",
    "    #     env_name = \"Passive Visual Match\"\n",
    "    #     path = f\"logs/visual_match/{env_len}\"\n",
    "    #     title_tag = \"memory length\"\n",
    "    #     metric = \"return\"  # or \"success\"\n",
    "\n",
    "    # for env_len, end in zip([60, 120, 250, 500], [3.4e6, 5.8e6, 7.5e6, 7.5e6]):\n",
    "    #     env_name = \"Key-to-Door\"\n",
    "    #     path = f\"logs/key_to_door/{env_len}\"\n",
    "    #     title_tag = \"credit assignment length\"\n",
    "    #     metric = \"success\"\n",
    "\n",
    "    # Runs under `path` accepted by query_fn; missing values become False.\n",
    "    df = walk_through(\n",
    "        path, metric, query_fn, start=0, end=end, steps=300, window=10\n",
    "    ).fillna(False)\n",
    "\n",
    "    # Short aliases for the flattened config columns used as hue/style.\n",
    "    df[\"seq\"] = df[\"config_seq.model.seq_model_config.name\"].str.upper()\n",
    "    n_layer_col = \"config_seq.model.seq_model_config.n_layer\"\n",
    "    if n_layer_col in df:\n",
    "        df[\"n_layer\"] = df[n_layer_col]\n",
    "\n",
    "    hue_order = None if hue is None else np.sort(df[hue].unique())\n",
    "    style_order = None if style is None else np.sort(df[style].unique())\n",
    "    ax = sns.lineplot(\n",
    "        data=df,\n",
    "        x=\"env_steps\",\n",
    "        y=metric,\n",
    "        hue=hue,\n",
    "        hue_order=hue_order,\n",
    "        style=style,\n",
    "        style_order=style_order,\n",
    "        palette=\"Dark2\",\n",
    "    )\n",
    "    if \"loss\" in metric:\n",
    "        ax.set_yscale(\"log\")\n",
    "    if not separate_legend:\n",
    "        ax.legend(framealpha=0.2)  # style the legend through the returned axes\n",
    "    else:\n",
    "        ax.legend().set_visible(False)\n",
    "\n",
    "    plt.xlim(0, end)\n",
    "    plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0, 0))  # default [-5, 6]\n",
    "    plt.title(f\"{env_name} w/ {title_tag} of {env_len}\")\n",
    "\n",
    "    # Flat output folder keeps the PDFs easy to upload to Overleaf.\n",
    "    os.makedirs(\"plts\", exist_ok=True)\n",
    "    env_dir = path.split(\"/\")[-2]\n",
    "    legend_tag = \"\" if separate_legend else \"_leg\"\n",
    "    out_file = f\"plts/{env_dir}-{env_len}_{metric}_{ablation}_{hue}_{style}{legend_tag}.pdf\"\n",
    "    plt.savefig(out_file, bbox_inches=\"tight\", pad_inches=0.03)  # pad_inches default is 0.1\n",
    "    plt.show()\n",
    "    plt.close()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Aggregation Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure style for the aggregation plot.\n",
    "sns.set_style(\"whitegrid\", {\"grid.linestyle\": \"--\"})\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"figure.dpi\": 300,\n",
    "        \"figure.figsize\": (4, 3),\n",
    "        \"axes.labelsize\": 12,\n",
    "        \"axes.titlesize\": 12,\n",
    "        \"xtick.labelsize\": 12,\n",
    "        \"ytick.labelsize\": 12,\n",
    "        \"legend.fontsize\": 12,\n",
    "        \"axes.grid\": True,\n",
    "        \"legend.loc\": \"best\",\n",
    "        \"lines.linewidth\": 1.5,\n",
    "        \"axes.formatter.useoffset\": False,\n",
    "        \"axes.formatter.offset_threshold\": 1,\n",
    "        # \"font.size\": 8,\n",
    "        \"font.family\": \"serif\",\n",
    "        \"font.serif\": [\"Liberation Serif\"],\n",
    "        \"text.usetex\": True,  # needs a LaTeX installation on the machine\n",
    "    }\n",
    ")\n",
    "\n",
    "\n",
    "def query_fn(flags):\n",
    "    \"\"\"Run filter passed to `walk_through` for the aggregation plots.\n",
    "\n",
    "    Reads the global `ablation` at call time (it is assigned later in this cell):\n",
    "    an empty string accepts every run; otherwise only runs whose sequence model\n",
    "    name equals `ablation` pass.\n",
    "    \"\"\"\n",
    "    if ablation == \"\":\n",
    "        # Optional extra filter for the main comparison:\n",
    "        # return flags['config_seq']['model']['seq_model_config']['n_layer'] == 1  # 2, 4\n",
    "        return True\n",
    "    seq_name = flags[\"config_seq\"][\"model\"][\"seq_model_config\"][\"name\"]\n",
    "    if seq_name != ablation:\n",
    "        return False\n",
    "    return True\n",
    "\n",
    "\n",
    "# Ablation switch read by query_fn: empty string keeps every sequence model; alternatives \"gpt\", \"lstm\".\n",
    "ablation = \"\"\n",
    "\n",
    "hue = (\n",
    "    \"seq\"\n",
    "    # \"n_layer\"\n",
    ")\n",
    "style = hue\n",
    "\n",
    "# One frame of per-run summary rows per environment length; concatenated after the loop.\n",
    "dfs = []\n",
    "\n",
    "for env_len, end in zip(\n",
    "    [50, 100, 250, 500, 750, 1000, 1250, 1500], [1e6, 2e6, 4e6, 4e6, 6e6, 8e6, 1e7, 1e7]\n",
    "):\n",
    "    env_name = \"Passive T-Maze\"\n",
    "    path = f\"logs/tmaze_passive/{env_len}\"\n",
    "    metric = \"return\"\n",
    "    x_axis = r\"(Easy) $\\leftarrow$ Memory length $\\rightarrow$ (Hard)\"\n",
    "    y_label = \"Optimal agent w/o memory\"\n",
    "    y_value = 0.5  # height of the reference line labelled y_label\n",
    "    show_legend = True\n",
    "\n",
    "    # for env_len, end in zip([20, 50, 100, 250, 500], [0.8e6, 2e6, 4e6, 7e6, 7e6]):\n",
    "    #     env_name = \"Active T-Maze\"\n",
    "    #     path = f\"logs/tmaze_active/{env_len}\"\n",
    "    #     metric = \"return\"\n",
    "    #     x_axis = r\"(Easy) $\\leftarrow$ Credit assignment length $\\rightarrow$ (Hard)\"\n",
    "    #     y_label = \"Optimal agent w/o credit assignment\"\n",
    "    #     y_value = 0.5\n",
    "    #     show_legend = True\n",
    "\n",
    "    # for env_len, end in zip([60, 120, 250, 500, 750, 1000],\n",
    "    #             [3.6e6, 4e6, 4e6, 5.3e6, 7.8e6, 10e6]):\n",
    "    #     env_name = \"Passive Visual Match\"\n",
    "    #     path = f\"logs/visual_match/{env_len}\"\n",
    "    #     metric = \"success\"\n",
    "    #     x_axis = r\"(Easy) $\\leftarrow$ Memory length $\\rightarrow$ (Hard)\"\n",
    "    #     y_label = \"Optimal agent w/o memory\"\n",
    "    #     y_value = 1/3\n",
    "    #     show_legend = False\n",
    "\n",
    "    # for env_len, end in zip([60, 120, 250, 500], [3.4e6, 5.8e6, 7.5e6, 7.5e6]):\n",
    "    #     env_name = \"Key-to-Door\"\n",
    "    #     path = f\"logs/key_to_door/{env_len}\"\n",
    "    #     metric = \"success\"\n",
    "    #     x_axis = r\"(Easy) $\\leftarrow$ Credit assignment length $\\rightarrow$ (Hard)\"\n",
    "    #     y_label = \"Optimal agent w/o credit assignment\"\n",
    "    #     y_value = 0.0\n",
    "    #     show_legend =  False # True #\n",
    "\n",
    "    df = walk_through(\n",
    "        path,\n",
    "        metric,\n",
    "        query_fn,\n",
    "        start=0,\n",
    "        end=end,\n",
    "        steps=300,\n",
    "        window=10,\n",
    "        cutoff=0.9,\n",
    "    )\n",
    "    df = df.fillna(False)\n",
    "\n",
    "    # custom functions to reduce flags\n",
    "    df[\"seq\"] = df[\"config_seq.model.seq_model_config.name\"].str.upper()\n",
    "    # The n_layer flag may be absent from a log set (the learning-curve cell guards\n",
    "    # it too); only derive the column when it exists instead of raising KeyError.\n",
    "    if \"config_seq.model.seq_model_config.n_layer\" in df:\n",
    "        df[\"n_layer\"] = df[\"config_seq.model.seq_model_config.n_layer\"].astype(int)\n",
    "\n",
    "    # take the average of the last 15/300=5% evaluation on that metric, and keep the other metrics as the last row.\n",
    "    final_values = (\n",
    "        df.groupby(\"run_name\")\n",
    "        .apply(lambda x: x.iloc[-1:].assign(**{metric: x[metric].tail(15).mean()}))\n",
    "        .reset_index(drop=True)\n",
    "    )\n",
    "    final_values[x_axis] = env_len\n",
    "    dfs.append(final_values)\n",
    "\n",
    "df = pd.concat(dfs, axis=0, ignore_index=True)\n",
    "\n",
    "# The full comparison (no ablation) is drawn wider with larger fonts.\n",
    "if ablation == \"\":\n",
    "    plt.rcParams[\"figure.figsize\"] = (8, 3)\n",
    "    plt.rcParams[\"axes.labelsize\"] = 17\n",
    "    plt.rcParams[\"axes.titlesize\"] = 17\n",
    "    plt.rcParams[\"xtick.labelsize\"] = 17\n",
    "    plt.rcParams[\"ytick.labelsize\"] = 17\n",
    "    plt.rcParams[\"legend.fontsize\"] = 17\n",
    "\n",
    "ans = sns.lineplot(\n",
    "    data=df,\n",
    "    x=x_axis,\n",
    "    y=metric,\n",
    "    palette=\"Dark2\" if ablation == \"\" else \"Set1\",\n",
    "    hue=hue,\n",
    "    hue_order=np.sort(df[hue].unique()) if hue is not None else None,\n",
    "    style=style,\n",
    "    style_order=np.sort(df[style].unique()) if style is not None else None,\n",
    "    markers=True,\n",
    "    dashes=False,\n",
    "    markersize=10,\n",
    ")\n",
    "if \"loss\" in metric:\n",
    "    ans.set_yscale(\"log\")\n",
    "\n",
    "if ablation == \"\":\n",
    "    # Reference level described by y_label (e.g. the optimum without memory).\n",
    "    plt.axhline(y=y_value, label=y_label)\n",
    "\n",
    "if show_legend:\n",
    "    ans.legend(framealpha=0.2)  # must use the returned ans\n",
    "else:\n",
    "    ans.legend().set_visible(False)\n",
    "\n",
    "plt.title(f\"{env_name}\")\n",
    "os.makedirs(\"plts\", exist_ok=True)  # use flattened folder for easy upload in overleaf\n",
    "plt.savefig(\n",
    "    f\"plts/{path.split('/')[-2]}_{ablation}{metric}_{hue}_{style}\"\n",
    "    + (\"_leg\" if show_legend else \"\")\n",
    "    + \".pdf\",\n",
    "    bbox_inches=\"tight\",\n",
    "    pad_inches=0.03,\n",
    ")  # default 0.1\n",
    "plt.show()\n",
    "plt.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "3d153d005c97a27d02bd55058c93c0fb18773b510051e37e91dbf10cc547ca4d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
