{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-15T09:07:30.097808Z",
     "start_time": "2025-05-15T09:07:30.093486Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "\n",
    "import os\n",
    "from utils.vis_tool import walk_through\n",
    "from itertools import product\n",
    "\n",
    "\n",
    "def reduce_shared_encoder(shared_encoder, freeze_critic, freeze_all):\n",
    "    if not shared_encoder:\n",
    "        return \"sep\"\n",
    "    if freeze_critic:\n",
    "        return \"sha-freeze\"\n",
    "    elif freeze_all:\n",
    "        return \"sha-freeze-all\"\n",
    "    else:\n",
    "        return \"sha-naive\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-15T09:07:33.179161Z",
     "start_time": "2025-05-15T09:07:33.165968Z"
    }
   },
   "outputs": [],
   "source": [
    "mpl.rcParams.update(mpl.rcParamsDefault)\n",
    "sns.set(style='whitegrid')\n",
    "plt.rcParams[\"figure.dpi\"] = 300\n",
    "plt.rcParams[\"figure.figsize\"] = (9, 2)\n",
    "plt.rcParams[\"axes.labelsize\"] = 10\n",
    "plt.rcParams[\"axes.titlesize\"] = 10\n",
    "plt.rcParams[\"xtick.labelsize\"] = 10\n",
    "plt.rcParams[\"ytick.labelsize\"] = 10\n",
    "plt.rcParams[\"legend.fontsize\"] = 10\n",
    "plt.rcParams[\"axes.grid\"] = True\n",
    "plt.rcParams[\"legend.loc\"] = \"best\"\n",
    "plt.rcParams[\"lines.linewidth\"] = 1.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-15T09:12:14.763592Z",
     "start_time": "2025-05-15T09:11:57.708877Z"
    }
   },
   "outputs": [],
   "source": [
    "metric = \"return\"\n",
    "separate_legend = False\n",
    "\n",
    "palette = {\n",
    "    'LRU': 'red',\n",
    "    'LSTM': 'blue',\n",
    "    'GPT': 'green',\n",
    "    'Ours': 'orange',\n",
    "}\n",
    "\n",
    "hue = (\"seq\")\n",
    "style = (None)\n",
    "\n",
    "def query_fn(flags):\n",
    "    if flags[\"config_seq\"][\"model\"][\"seq_model_config\"][\"name\"] == 'lifgate':\n",
    "        return True\n",
    "    return True\n",
    "\n",
    "first_col = [0]\n",
    "last_row = [0, 1, 2, 3]\n",
    "\n",
    "\n",
    "for idx, (env_type, env) in enumerate(product([\"p\"], [\"ant\", \"cheetah\", \"hopper\",\"walker\"])):\n",
    "    plt.subplot(1, 4, idx+1)\n",
    "    base_path = \"logs/pomdp/bullet\"\n",
    "    path = base_path + f\"_{env_type}/{env}\"\n",
    "\n",
    "    end = 1.5e6\n",
    "    df = walk_through(\n",
    "        path,\n",
    "        metric,\n",
    "        query_fn,\n",
    "        start=0,\n",
    "        end=end,\n",
    "        steps=100,\n",
    "        window=10,\n",
    "    )\n",
    "    df = df.fillna(False)\n",
    "\n",
    "    # custom functions to reduce flags\n",
    "    df[\"encoder\"] = df.apply(\n",
    "        lambda row: reduce_shared_encoder(\n",
    "            row[\"shared_encoder\"], row[\"freeze_critic\"], row[\"freeze_all\"]\n",
    "        ),\n",
    "        axis=1,\n",
    "    )\n",
    "    df[\"seq\"] = df[\"config_seq.model.seq_model_config.name\"].str.upper()\n",
    "    df[\"seq\"] = df[\"seq\"].replace('LIFGATE', 'Ours')\n",
    "    \n",
    "    ans = sns.lineplot(\n",
    "        data=df,\n",
    "        x=\"env_steps\",\n",
    "        y=metric,\n",
    "        palette=palette,\n",
    "        hue=hue,\n",
    "        hue_order=np.sort(df[hue].unique()) if hue is not None else None,\n",
    "        style=style,\n",
    "        style_order=np.sort(df[style].unique()) if style is not None else None,\n",
    "    )\n",
    "    \n",
    "    if idx not in first_col:\n",
    "        plt.ylabel('')\n",
    "    \n",
    "    if idx not in last_row:\n",
    "        plt.xlabel('')\n",
    "        \n",
    "    if idx in last_row:\n",
    "        plt.xticks([0, 0.5e6, 1e6, 1.5e6])\n",
    "    else:\n",
    "        plt.xticks([0, 0.5e6, 1e6, 1.5e6], labels=['', '', '', ''])\n",
    "\n",
    "    ans.legend().set_visible(False)\n",
    "\n",
    "    plt.xlim(0, end)\n",
    "    plt.tick_params(pad=0)\n",
    "    plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0, 0), )  # default [-5, 6]\n",
    "    plt.title(f\"{env.title()}-{env_type.title()}\")\n",
    "\n",
    "ans.legend(ncols=6, bbox_to_anchor=(-0., 1.4))\n",
    "\n",
    "os.makedirs(\"plts\", exist_ok=True)\n",
    "plt.savefig(\n",
    "    \"plts/\"\n",
    "    + path.replace(\"logs/\", \"\").replace(\"/\", \"-\")\n",
    "    + f\"_{metric}_{hue}_{style}\"\n",
    "    + (\"\" if separate_legend else \"_leg\")\n",
    "    + \".pdf\",\n",
    "    bbox_inches=\"tight\",\n",
    "    pad_inches=0.03,\n",
    ")  # default 0.1\n",
    "plt.show()\n",
    "plt.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.5 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3d153d005c97a27d02bd55058c93c0fb18773b510051e37e91dbf10cc547ca4d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
