{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-22T13:28:23.510368Z",
     "start_time": "2025-05-22T13:28:23.282678Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "\n",
    "import os\n",
    "from utils.vis_tool import walk_through\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
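    "# Learning Curves\n",
    "\n",
    "Success rate versus environment steps on the Visual Match task (memory lengths 250 and 500) for LIFGATE (\"Ours\"), LRU, FFM and single-layer GPT/LSTM sequence models; the resulting figure is saved as a PDF under `plts/`."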
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-22T13:29:50.913902Z",
     "start_time": "2025-05-22T13:29:42.525213Z"
    }
   },
   "outputs": [],
   "source": [
    "# Plot styling. Apply the seaborn theme FIRST: sns.set() installs its \"notebook\"\n",
    "# plotting context, which resets font sizes and line widths, so any rcParams\n",
    "# overrides must come after it or they are silently discarded.\n",
    "mpl.rcParams.update(mpl.rcParamsDefault)\n",
    "sns.set(style='whitegrid')\n",
    "plt.rcParams[\"figure.dpi\"] = 300\n",
    "plt.rcParams[\"figure.figsize\"] = (9, 3)\n",
    "plt.rcParams[\"axes.labelsize\"] = 15\n",
    "plt.rcParams[\"axes.titlesize\"] = 12\n",
    "plt.rcParams[\"xtick.labelsize\"] = 14\n",
    "plt.rcParams[\"ytick.labelsize\"] = 14\n",
    "plt.rcParams[\"legend.fontsize\"] = 8  # 10\n",
    "plt.rcParams[\"lines.linewidth\"] = 1.5\n",
    "\n",
    "# Fixed colour per sequence-model label so both panels and the shared legend agree.\n",
    "palette = {\n",
    "    'LRU': 'red',\n",
    "    'LSTM': 'blue',\n",
    "    'GPT': 'green',\n",
    "    'Ours': 'orange',\n",
    "    'FFM': 'cyan',\n",
    "}\n",
    "\n",
    "ablation = \"\"  # free-form tag embedded in the output filename (e.g. \"gpt\", \"lstm\")\n",
    "hue = \"seq\"  # column that colours the curves\n",
    "style = None  # optional column for line style; None disables it\n",
    "\n",
    "def query_fn(flags):\n",
    "    \"\"\"Return True for runs that should appear in the learning-curve plot.\n",
    "\n",
    "    Keeps single-layer GPT and LSTM runs plus every LIFGATE, LRU and FFM run;\n",
    "    runs of any other sequence model are filtered out.\n",
    "    \"\"\"\n",
    "    seq_cfg = flags[\"config_seq\"][\"model\"][\"seq_model_config\"]\n",
    "    seq_name = seq_cfg[\"name\"]\n",
    "    if seq_name in ('gpt', 'lstm') and seq_cfg[\"n_layer\"] == 1:\n",
    "        return True\n",
    "    return seq_name in ('lifgate', 'lru', 'ffm')\n",
    "\n",
    "# One panel per memory length: success rate vs. environment steps for every run\n",
    "# selected by query_fn. A fresh figure is created on each execution so re-running\n",
    "# the cell never draws onto a stale figure left over from a previous run.\n",
    "env_name = \"Visual Match\"\n",
    "title_tag = \"memory length\"\n",
    "log_root = \"logs/visual_match\"\n",
    "metric = \"success\"  # or \"return\"\n",
    "env_lengths = [250, 500]\n",
    "step_limits = [1e6, 1e6]  # passed to walk_through as `end` and used as the x-axis limit\n",
    "\n",
    "fig, axes = plt.subplots(1, len(env_lengths), squeeze=False)\n",
    "for ax, env_len, end in zip(axes[0], env_lengths, step_limits):\n",
    "    path = f\"{log_root}/{env_len}\"\n",
    "    df = walk_through(\n",
    "        path,\n",
    "        metric,\n",
    "        query_fn,\n",
    "        start=0,\n",
    "        end=end,\n",
    "        steps=100,\n",
    "        window=5,\n",
    "    ).fillna(False)\n",
    "\n",
    "    # Derive short, readable labels from the flattened config columns.\n",
    "    df[\"seq\"] = df[\"config_seq.model.seq_model_config.name\"].str.upper().replace('LIFGATE', 'Ours')\n",
    "    if \"config_seq.model.seq_model_config.n_layer\" in df:\n",
    "        df[\"n_layer\"] = df[\"config_seq.model.seq_model_config.n_layer\"]\n",
    "    df['run_name'] = df['run_name'].str[:13]\n",
    "\n",
    "    sns.lineplot(\n",
    "        data=df,\n",
    "        x=\"env_steps\",\n",
    "        y=metric,\n",
    "        palette=palette,\n",
    "        hue=hue,\n",
    "        hue_order=np.sort(df[hue].unique()) if hue is not None else None,\n",
    "        style=style,\n",
    "        style_order=np.sort(df[style].unique()) if style is not None else None,\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "    # Hide the per-panel legend; one shared legend is added after the loop.\n",
    "    ax.legend().set_visible(False)\n",
    "    ax.set_xlim(0, end)\n",
    "    ax.set_ylim(-0.05, 1.05)\n",
    "    ax.axhline(y=1/3)  # reference line at 1/3 (presumably chance level; confirm)\n",
    "    ax.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0, 0))  # default [-5, 6]\n",
    "    ax.set_title(f\"{env_name} w/ {title_tag} of {env_len}\")\n",
    "\n",
    "# Shared legend, anchored above the last panel.\n",
    "axes[0, -1].legend(ncols=6, bbox_to_anchor=(0.7, 1.3))\n",
    "\n",
    "os.makedirs(\"plts\", exist_ok=True)\n",
    "plot_path = (\n",
    "    f\"plts/{os.path.basename(log_root)}-{env_lengths[-1]}\"\n",
    "    f\"_{metric}_{ablation}_{hue}_{style}.pdf\"\n",
    ")\n",
    "fig.savefig(plot_path, bbox_inches=\"tight\", pad_inches=0.03)  # pad_inches default 0.1\n",
    "plt.show()\n",
    "plt.close(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "3d153d005c97a27d02bd55058c93c0fb18773b510051e37e91dbf10cc547ca4d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
