{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import sys\n",
    "import os\n",
    "\n",
    "# Expose the local MuJoCo 2.1.0 binaries before d4rl is imported below.\n",
    "MUJOCO_BIN = os.path.expanduser(\"~/.mujoco/mujoco210/bin\")\n",
    "sys.path.append(MUJOCO_BIN)\n",
    "# NOTE: this replaces (does not extend) any LD_LIBRARY_PATH that was already set.\n",
    "os.environ[\"LD_LIBRARY_PATH\"] = f\"{MUJOCO_BIN}:/usr/lib/nvidia/\"\n",
    "\n",
    "from d4rl import infos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mean</th>\n",
       "      <th>sterr</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>env</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>ant-expert-v2</th>\n",
       "      <td>1.315764</td>\n",
       "      <td>0.010631</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ant-medium-v2</th>\n",
       "      <td>1.055235</td>\n",
       "      <td>0.011786</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ant-random-v2</th>\n",
       "      <td>0.389008</td>\n",
       "      <td>0.015241</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>halfcheetah-expert-v2</th>\n",
       "      <td>0.710352</td>\n",
       "      <td>0.026557</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>halfcheetah-medium-v2</th>\n",
       "      <td>0.632496</td>\n",
       "      <td>0.002517</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>halfcheetah-random-v2</th>\n",
       "      <td>0.316693</td>\n",
       "      <td>0.003897</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hopper-expert-v2</th>\n",
       "      <td>1.017680</td>\n",
       "      <td>0.025706</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hopper-medium-v2</th>\n",
       "      <td>0.738426</td>\n",
       "      <td>0.043585</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hopper-random-v2</th>\n",
       "      <td>0.065391</td>\n",
       "      <td>0.000267</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>walker2d-expert-v2</th>\n",
       "      <td>1.083300</td>\n",
       "      <td>0.001100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>walker2d-medium-v2</th>\n",
       "      <td>0.839295</td>\n",
       "      <td>0.003030</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>walker2d-random-v2</th>\n",
       "      <td>0.053028</td>\n",
       "      <td>0.002011</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                           mean     sterr\n",
       "env                                      \n",
       "ant-expert-v2          1.315764  0.010631\n",
       "ant-medium-v2          1.055235  0.011786\n",
       "ant-random-v2          0.389008  0.015241\n",
       "halfcheetah-expert-v2  0.710352  0.026557\n",
       "halfcheetah-medium-v2  0.632496  0.002517\n",
       "halfcheetah-random-v2  0.316693  0.003897\n",
       "hopper-expert-v2       1.017680  0.025706\n",
       "hopper-medium-v2       0.738426  0.043585\n",
       "hopper-random-v2       0.065391  0.000267\n",
       "walker2d-expert-v2     1.083300  0.001100\n",
       "walker2d-medium-v2     0.839295  0.003030\n",
       "walker2d-random-v2     0.053028  0.002011"
      ]
     },
     "execution_count": 109,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Get last step for each run:\n",
    "# last_step = df.groupby([\"run\"])[\"step\"].idxmax()\n",
    "# Find and discard the worst two runs:\n",
    "# df.loc[last_step,].groupby([\"env\", \"use_pop\"]).min()\n",
    "# df.groupby([\"env\", \"use_pop\"]).last()\n",
    "\n",
    "def getGroup(df):\n",
    "    \"\"\"Return the per-`env` groupby over the `average_return` column of `df`.\"\"\"\n",
    "    return df.groupby([\"env\"])[\"average_return\"]\n",
    "\n",
    "\n",
    "def processExperiment(fn: str):\n",
    "    \"\"\"Summarise one experiment CSV as a normalized mean and standard error per env.\n",
    "\n",
    "    Reads `fn`, groups `average_return` by `env`, discards the two lowest returns\n",
    "    in every group and rescales the rest with the D4RL reference scores, so that\n",
    "    REF_MIN_SCORE maps to 0 and REF_MAX_SCORE maps to 1.\n",
    "\n",
    "    Returns a DataFrame indexed by `env` with the columns `mean` and `sterr`.\n",
    "    \"\"\"\n",
    "    runs = pd.read_csv(fn)\n",
    "\n",
    "    # One row per env holding the list of returns left after dropping the two worst.\n",
    "    kept = (\n",
    "        getGroup(runs)\n",
    "        .agg(lambda returns: list(returns.drop(returns.nsmallest(2).index)))\n",
    "        .reset_index(-1)\n",
    "    )\n",
    "\n",
    "    def normalize(row):\n",
    "        # Fraction of the reference range: 0 at REF_MIN_SCORE, 1 at REF_MAX_SCORE.\n",
    "        lo = infos.REF_MIN_SCORE[row[\"env\"]]\n",
    "        hi = infos.REF_MAX_SCORE[row[\"env\"]]\n",
    "        return [(score - lo) / (hi - lo) for score in row[\"average_return\"]]\n",
    "\n",
    "    kept[\"average_return\"] = kept.apply(normalize, axis=1)\n",
    "    scores = kept[\"average_return\"]\n",
    "\n",
    "    mean = (\n",
    "        scores.apply(lambda vals: np.mean(vals))\n",
    "        .reset_index(-1, drop=True)\n",
    "        .rename(\"mean\")\n",
    "    )\n",
    "    # NOTE(review): np.std defaults to ddof=0 (population std); confirm that is the\n",
    "    # intended estimator for the standard error.\n",
    "    sterr = (\n",
    "        scores.apply(lambda vals: np.std(vals) / np.sqrt(len(vals)))\n",
    "        .reset_index(-1, drop=True)\n",
    "        .rename(\"sterr\")\n",
    "    )\n",
    "\n",
    "    return pd.concat([kept[\"env\"], mean, sterr], axis=1).set_index(\"env\")\n",
    "\n",
    "\n",
    "# Preview the CQL summary; values are fractions of the reference range (scaled by 100 later).\n",
    "cql = processExperiment(\"cql.csv\")\n",
    "cql\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get all the returns:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarise each method's runs with the same pipeline.\n",
    "cql, ddcql, popcql = (\n",
    "    processExperiment(csv_file)\n",
    "    for csv_file in (\"cql.csv\", \"dualdicecql.csv\", \"popcql.csv\")\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mean</th>\n",
       "      <th>sterr</th>\n",
       "      <th>mean_ddcql</th>\n",
       "      <th>sterr_ddcql</th>\n",
       "      <th>mean_popcql</th>\n",
       "      <th>sterr_popcql</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>env</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>ant-expert-v2</th>\n",
       "      <td>131.6</td>\n",
       "      <td>1.1</td>\n",
       "      <td>126.2</td>\n",
       "      <td>2.4</td>\n",
       "      <td>130.7</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ant-medium-v2</th>\n",
       "      <td>105.5</td>\n",
       "      <td>1.2</td>\n",
       "      <td>103.8</td>\n",
       "      <td>0.3</td>\n",
       "      <td>103.5</td>\n",
       "      <td>0.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ant-random-v2</th>\n",
       "      <td>38.9</td>\n",
       "      <td>1.5</td>\n",
       "      <td>36.5</td>\n",
       "      <td>1.3</td>\n",
       "      <td>43.4</td>\n",
       "      <td>2.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>halfcheetah-expert-v2</th>\n",
       "      <td>71.0</td>\n",
       "      <td>2.7</td>\n",
       "      <td>66.5</td>\n",
       "      <td>3.2</td>\n",
       "      <td>58.7</td>\n",
       "      <td>2.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>halfcheetah-medium-v2</th>\n",
       "      <td>63.2</td>\n",
       "      <td>0.3</td>\n",
       "      <td>53.9</td>\n",
       "      <td>0.2</td>\n",
       "      <td>60.0</td>\n",
       "      <td>1.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>halfcheetah-random-v2</th>\n",
       "      <td>31.7</td>\n",
       "      <td>0.4</td>\n",
       "      <td>26.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>29.8</td>\n",
       "      <td>0.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hopper-expert-v2</th>\n",
       "      <td>101.8</td>\n",
       "      <td>2.6</td>\n",
       "      <td>105.4</td>\n",
       "      <td>1.5</td>\n",
       "      <td>104.2</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hopper-medium-v2</th>\n",
       "      <td>73.8</td>\n",
       "      <td>4.4</td>\n",
       "      <td>70.1</td>\n",
       "      <td>2.7</td>\n",
       "      <td>66.2</td>\n",
       "      <td>3.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hopper-random-v2</th>\n",
       "      <td>6.5</td>\n",
       "      <td>0.0</td>\n",
       "      <td>7.6</td>\n",
       "      <td>0.3</td>\n",
       "      <td>6.5</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>walker2d-expert-v2</th>\n",
       "      <td>108.3</td>\n",
       "      <td>0.1</td>\n",
       "      <td>109.7</td>\n",
       "      <td>0.4</td>\n",
       "      <td>108.0</td>\n",
       "      <td>0.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>walker2d-medium-v2</th>\n",
       "      <td>83.9</td>\n",
       "      <td>0.3</td>\n",
       "      <td>83.3</td>\n",
       "      <td>0.4</td>\n",
       "      <td>83.7</td>\n",
       "      <td>0.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>walker2d-random-v2</th>\n",
       "      <td>5.3</td>\n",
       "      <td>0.2</td>\n",
       "      <td>9.3</td>\n",
       "      <td>0.8</td>\n",
       "      <td>7.1</td>\n",
       "      <td>0.4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                        mean  sterr  mean_ddcql  sterr_ddcql  mean_popcql  \\\n",
       "env                                                                         \n",
       "ant-expert-v2          131.6    1.1       126.2          2.4        130.7   \n",
       "ant-medium-v2          105.5    1.2       103.8          0.3        103.5   \n",
       "ant-random-v2           38.9    1.5        36.5          1.3         43.4   \n",
       "halfcheetah-expert-v2   71.0    2.7        66.5          3.2         58.7   \n",
       "halfcheetah-medium-v2   63.2    0.3        53.9          0.2         60.0   \n",
       "halfcheetah-random-v2   31.7    0.4        26.5          0.2         29.8   \n",
       "hopper-expert-v2       101.8    2.6       105.4          1.5        104.2   \n",
       "hopper-medium-v2        73.8    4.4        70.1          2.7         66.2   \n",
       "hopper-random-v2         6.5    0.0         7.6          0.3          6.5   \n",
       "walker2d-expert-v2     108.3    0.1       109.7          0.4        108.0   \n",
       "walker2d-medium-v2      83.9    0.3        83.3          0.4         83.7   \n",
       "walker2d-random-v2       5.3    0.2         9.3          0.8          7.1   \n",
       "\n",
       "                       sterr_popcql  \n",
       "env                                  \n",
       "ant-expert-v2                   1.0  \n",
       "ant-medium-v2                   0.8  \n",
       "ant-random-v2                   2.2  \n",
       "halfcheetah-expert-v2           2.4  \n",
       "halfcheetah-medium-v2           1.1  \n",
       "halfcheetah-random-v2           0.5  \n",
       "hopper-expert-v2                1.0  \n",
       "hopper-medium-v2                3.6  \n",
       "hopper-random-v2                0.0  \n",
       "walker2d-expert-v2              0.2  \n",
       "walker2d-medium-v2              0.6  \n",
       "walker2d-random-v2              0.4  "
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# CQL keeps the bare column names; the other two methods get a suffix.\n",
    "table = cql.join(ddcql, rsuffix=\"_ddcql\").join(popcql, rsuffix=\"_popcql\")\n",
    "# Show as percentages, rounded to one decimal.\n",
    "table.mul(100).round(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [],
   "source": [
    "def envname(table):\n",
    "    \"\"\"Prepend `env_name` and `env_variant` columns parsed from the index of `table`.\n",
    "\n",
    "    Each index label is split on \"-\"; the first piece becomes `env_name` and the\n",
    "    second `env_variant` (e.g. \"ant-expert-v2\" -> \"ant\", \"expert\").\n",
    "    \"\"\"\n",
    "    labels = table.index\n",
    "    pieces = labels.str.split(\"-\", expand=True)\n",
    "    name_col = pd.Series(pieces.get_level_values(0), index=labels, name=\"env_name\")\n",
    "    variant_col = pd.Series(pieces.get_level_values(1), index=labels, name=\"env_variant\")\n",
    "    return pd.concat([name_col, variant_col, table], axis=1)\n",
    "\n",
    "# Add the env_name / env_variant columns parsed from each index label.\n",
    "out = envname(table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\multicolumn{2}{l}{Environment} & \\multicolumn{3}{c}{CQL} & \\multicolumn{3}{c}{DualDICE + CQL} & \\multicolumn{3}{c}{POP-CQL}\n",
      "\\\\ \\hline \\texttt{ant         } & \\texttt{expert      } & \\textbf{131.6} & $\\bm\\pm$ & \\textbf{1.1} & 126.2 & $\\pm$ & 2.4 & \\textbf{130.7} & $\\bm\\pm$ & \\textbf{1.0}\n",
      "\\\\  & \\texttt{medium      } & \\textbf{105.5} & $\\bm\\pm$ & \\textbf{1.2} & 103.8 & $\\pm$ & 0.3 & 103.5 & $\\pm$ & 0.8\n",
      "\\\\  & \\texttt{random      } & 38.9 & $\\pm$ & 1.5 & 36.5 & $\\pm$ & 1.3 & \\textbf{43.4} & $\\bm\\pm$ & \\textbf{2.2}\n",
      "\\\\ \\hline \\texttt{halfcheetah } & \\texttt{expert      } & \\textbf{71.0} & $\\bm\\pm$ & \\textbf{2.7} & 66.5 & $\\pm$ & 3.2 & 58.7 & $\\pm$ & 2.4\n",
      "\\\\  & \\texttt{medium      } & \\textbf{63.2} & $\\bm\\pm$ & \\textbf{0.3} & 53.9 & $\\pm$ & 0.2 & 60.0 & $\\pm$ & 1.1\n",
      "\\\\  & \\texttt{random      } & \\textbf{31.7} & $\\bm\\pm$ & \\textbf{0.4} & 26.5 & $\\pm$ & 0.2 & 29.8 & $\\pm$ & 0.5\n",
      "\\\\ \\hline \\texttt{hopper      } & \\texttt{expert      } & 101.8 & $\\pm$ & 2.6 & \\textbf{105.4} & $\\bm\\pm$ & \\textbf{1.5} & \\textbf{104.2} & $\\bm\\pm$ & \\textbf{1.0}\n",
      "\\\\  & \\texttt{medium      } & \\textbf{73.8} & $\\bm\\pm$ & \\textbf{4.4} & \\textbf{70.1} & $\\bm\\pm$ & \\textbf{2.7} & 66.2 & $\\pm$ & 3.6\n",
      "\\\\  & \\texttt{random      } & 6.5 & $\\pm$ & 0.0 & \\textbf{7.6} & $\\bm\\pm$ & \\textbf{0.3} & 6.5 & $\\pm$ & 0.0\n",
      "\\\\ \\hline \\texttt{walker2d    } & \\texttt{expert      } & 108.3 & $\\pm$ & 0.1 & \\textbf{109.7} & $\\bm\\pm$ & \\textbf{0.4} & 108.0 & $\\pm$ & 0.2\n",
      "\\\\  & \\texttt{medium      } & \\textbf{83.9} & $\\bm\\pm$ & \\textbf{0.3} & 83.3 & $\\pm$ & 0.4 & \\textbf{83.7} & $\\bm\\pm$ & \\textbf{0.6}\n",
      "\\\\  & \\texttt{random      } & 5.3 & $\\pm$ & 0.2 & \\textbf{9.3} & $\\bm\\pm$ & \\textbf{0.8} & 7.1 & $\\pm$ & 0.4\n",
      "\\\\ \\hline\\hline\n"
     ]
    }
   ],
   "source": [
    "# (mean column, standard-error column) for CQL, DualDICE + CQL and POP-CQL.\n",
    "COLUMN_PAIRS = [\n",
    "    (\"mean\", \"sterr\"),\n",
    "    (\"mean_ddcql\", \"sterr_ddcql\"),\n",
    "    (\"mean_popcql\", \"sterr_popcql\"),\n",
    "]\n",
    "\n",
    "\n",
    "def entry(mn, se, lcb):\n",
    "    \"\"\"Return three LaTeX cells (mean, plus-minus sign, sterr) as percentages; bold when mn >= lcb.\"\"\"\n",
    "    if mn >= lcb:\n",
    "        return [f\"\\\\textbf{{{mn*100:.1f}}}\", \"$\\\\bm\\\\pm$\", f\"\\\\textbf{{{se*100:.1f}}}\"]\n",
    "    return [f\"{mn*100:.1f}\", \"$\\\\pm$\", f\"{se*100:.1f}\"]\n",
    "\n",
    "\n",
    "print(\n",
    "    r\"\\multicolumn{2}{l}{Environment} & \\multicolumn{3}{c}{CQL} & \\multicolumn{3}{c}{DualDICE + CQL} & \\multicolumn{3}{c}{POP-CQL}\"\n",
    ")\n",
    "\n",
    "previous_env = None\n",
    "for _, row in out.iterrows():\n",
    "    env = row[\"env_name\"]\n",
    "    # Print the environment name (preceded by a rule) only on its first row.\n",
    "    if env != previous_env:\n",
    "        previous_env = env\n",
    "        envstr = f\"\\\\hline \\\\texttt{{{env:12}}}\"\n",
    "    else:\n",
    "        envstr = \"\"\n",
    "\n",
    "    # Highest (mean - sterr) among the three methods; every method whose mean\n",
    "    # reaches it is typeset in bold.\n",
    "    lcb = max(row[mean_col] - row[se_col] for mean_col, se_col in COLUMN_PAIRS)\n",
    "\n",
    "    cells = [\"\\\\\\\\ \" + envstr, f\"\\\\texttt{{{row['env_variant']:12}}}\"]\n",
    "    for mean_col, se_col in COLUMN_PAIRS:\n",
    "        cells += entry(row[mean_col], row[se_col], lcb)\n",
    "    print(\" & \".join(cells))\n",
    "print(\"\\\\\\\\ \\\\hline\\\\hline\")\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Raw scores for appendix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\paragraph{{CQL}}\n",
      "\\begin{tabular}{llrrrrrrrr}\n",
      "\\toprule\n",
      "         &        &  (0) &  (1) &  (2) &  (3) &  (4) &  (5) &  (6) &  (7) \\\\\n",
      "env\\_name & env\\_variant &                      &                      &                      &                      &                      &                      &                      &                      \\\\\n",
      "\\midrule\n",
      "ant & expert &               5129.7 &               5402.0 &               5069.7 &               4740.0 &               5258.5 &               5108.3 &               5091.9 &               5255.1 \\\\\n",
      "         & medium &               3995.3 &               4079.9 &               4311.8 &               4009.5 &               3779.7 &               4242.7 &               3906.9 &               4032.6 \\\\\n",
      "         & random &               1442.6 &                858.8 &                812.1 &               1432.6 &               1127.0 &               1082.4 &               1294.7 &               1482.5 \\\\\n",
      "halfcheetah & expert &              10041.8 &               9179.1 &               6708.5 &               8060.3 &               4842.6 &               8228.0 &               7820.8 &               7903.8 \\\\\n",
      "         & medium &               7538.9 &               7605.4 &               7499.3 &               7432.0 &               7725.4 &               7555.2 &               7510.1 &               6822.8 \\\\\n",
      "         & random &               3658.6 &               3830.1 &               3751.8 &               3628.3 &               -293.7 &               3581.2 &               2940.6 &               3459.7 \\\\\n",
      "hopper & expert &               3152.3 &               2593.1 &               3442.5 &               3499.4 &               3533.7 &               3029.2 &               2516.5 &               3093.8 \\\\\n",
      "         & medium &               2843.1 &               1960.5 &               1854.3 &               2318.6 &               2791.2 &               2405.2 &               1965.2 &               1974.6 \\\\\n",
      "         & random &                191.5 &                195.0 &                190.6 &                187.0 &                193.5 &                184.6 &                189.6 &                195.1 \\\\\n",
      "walker2d & expert &               4973.5 &               4989.7 &               4928.2 &               4989.6 &               4953.0 &               4966.6 &               4974.3 &               4954.6 \\\\\n",
      "         & medium &               3837.3 &               3831.9 &               3890.8 &               3724.2 &               3911.9 &               3834.8 &               3601.2 &               3820.6 \\\\\n",
      "         & random &                227.5 &                232.9 &                284.7 &                264.2 &                213.8 &                241.7 &                169.7 &                219.3 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n",
      "\\paragraph{{DualDICE + CQL}}\n",
      "\\begin{tabular}{llrrrrrrrr}\n",
      "\\toprule\n",
      "         &        &  (0) &  (1) &  (2) &  (3) &  (4) &  (5) &  (6) &  (7) \\\\\n",
      "env\\_name & env\\_variant &                      &                      &                      &                      &                      &                      &                      &                      \\\\\n",
      "\\midrule\n",
      "ant & expert &               4140.7 &               5011.6 &               5198.5 &               4429.6 &               4966.6 &               5285.7 &               4903.6 &               4529.0 \\\\\n",
      "         & medium &               4056.3 &               4076.0 &               3989.8 &               3878.7 &               3796.1 &               4033.4 &               4028.7 &               4056.0 \\\\\n",
      "         & random &               1251.8 &                999.2 &               1013.7 &                672.4 &               1327.5 &               1041.0 &               1287.7 &               1321.9 \\\\\n",
      "halfcheetah & expert &               4790.4 &               9272.3 &               8710.6 &               8479.7 &               6321.2 &               7471.1 &               7605.5 &               5314.2 \\\\\n",
      "         & medium &               6264.5 &               6477.0 &               6435.3 &               6325.1 &               6382.9 &               6319.6 &               6395.0 &               6418.3 \\\\\n",
      "         & random &               2939.0 &               3046.8 &               2920.6 &               3075.1 &               3115.7 &               2934.2 &               2292.2 &               2944.8 \\\\\n",
      "hopper & expert &               3391.7 &               3332.6 &               3486.3 &               3486.4 &               3112.9 &               3557.9 &               3157.2 &               3208.9 \\\\\n",
      "         & medium &               2262.1 &               1542.4 &               2212.6 &               1846.6 &               2018.9 &               2708.5 &               2218.7 &               2144.5 \\\\\n",
      "         & random &                179.2 &                267.6 &                228.9 &                212.3 &                192.1 &                231.0 &                175.4 &                231.7 \\\\\n",
      "walker2d & expert &               4904.9 &               5101.5 &               5028.9 &               4972.0 &               4931.7 &               5000.7 &               5089.0 &               5029.8 \\\\\n",
      "         & medium &               3795.4 &               3848.3 &               3798.3 &               3780.8 &               3730.8 &               3794.1 &               3926.6 &                  3.3 \\\\\n",
      "         & random &                205.4 &                543.4 &                322.2 &                377.6 &                332.5 &                182.4 &                561.6 &                443.5 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n",
      "\\paragraph{{POP-CQL}}\n",
      "\\begin{tabular}{llrrrrrrrr}\n",
      "\\toprule\n",
      "         &        &  (0) &  (1) &  (2) &  (3) &  (4) &  (5) &  (6) &  (7) \\\\\n",
      "env\\_name & env\\_variant &                      &                      &                      &                      &                      &                      &                      &                      \\\\\n",
      "\\midrule\n",
      "ant & expert &               4850.1 &               5084.7 &               5329.2 &               5002.8 &               5213.8 &               4694.8 &               5236.4 &               5159.3 \\\\\n",
      "         & medium &              -2103.7 &               3980.8 &               3921.1 &               4054.9 &               4161.2 &               3945.3 &               3284.0 &               4088.3 \\\\\n",
      "         & random &               1642.2 &               1577.9 &               1007.5 &               1450.3 &               1857.9 &               1182.7 &               1284.9 &               1134.2 \\\\\n",
      "halfcheetah & expert &               7519.3 &               6713.0 &               5647.4 &               6773.7 &               7907.1 &               5576.3 &               7505.9 &               5327.5 \\\\\n",
      "         & medium &               -267.5 &               7363.0 &               7401.7 &               6874.6 &               6536.6 &               7427.0 &               5580.9 &               7430.1 \\\\\n",
      "         & random &               -414.3 &               3216.6 &               3649.7 &               -318.9 &               3607.3 &               3285.6 &               3405.5 &               3342.2 \\\\\n",
      "hopper & expert &               3377.3 &               3097.7 &               3345.7 &               2550.0 &               3251.9 &               3343.3 &               3531.7 &               3375.7 \\\\\n",
      "         & medium &               2227.7 &               1841.8 &               2710.5 &               1155.3 &                 39.9 &               2126.2 &               2006.9 &               1892.0 \\\\\n",
      "         & random &                187.7 &                190.5 &                187.7 &                191.9 &                183.6 &                193.4 &                195.1 &                193.7 \\\\\n",
      "walker2d & expert &               4968.1 &               4984.4 &               4976.1 &               4930.3 &               4929.4 &               4965.7 &               4937.3 &               4929.7 \\\\\n",
      "         & medium &               3637.4 &               3786.4 &               3735.3 &               3939.4 &               3674.0 &               3817.0 &               3859.7 &               3914.3 \\\\\n",
      "         & random &                397.7 &                203.6 &                309.6 &                259.9 &                325.1 &                323.9 &                336.7 &                256.7 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<>:11: DeprecationWarning: invalid escape sequence \\p\n",
      "/tmp/ipykernel_1857648/1777620737.py:7: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n",
      "  rv = df.to_latex(float_format=\"%.1f\")\n",
      "/tmp/ipykernel_1857648/1777620737.py:7: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n",
      "  rv = df.to_latex(float_format=\"%.1f\")\n",
      "/tmp/ipykernel_1857648/1777620737.py:7: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n",
      "  rv = df.to_latex(float_format=\"%.1f\")\n"
     ]
    }
   ],
   "source": [
    "def processAppendix(fn: str, base_seed=0xCAFECAF0):\n",
    "    \"\"\"Render the raw per-seed returns stored in `fn` as a LaTeX table.\n",
    "\n",
    "    `base_seed` is subtracted from the `seed` column before the seeds become the\n",
    "    column labels; rows are keyed by (`env_name`, `env_variant`).\n",
    "    \"\"\"\n",
    "    runs = pd.read_csv(fn)\n",
    "    runs[\"seed\"] = runs[\"seed\"] - base_seed\n",
    "    per_seed = runs[[\"env\", \"seed\", \"average_return\"]].pivot(index=\"env\", columns=\"seed\")\n",
    "    keyed = (\n",
    "        envname(per_seed)\n",
    "        .reset_index(drop=True)\n",
    "        .set_index([\"env_name\", \"env_variant\"])\n",
    "    )\n",
    "    # Strip the value-column prefix so only the seed offset is left in the header.\n",
    "    return keyed.to_latex(float_format=\"%.1f\").replace(\"average\\\\_return, \", \"\")\n",
    "\n",
    "# (title, csv file, base seed) per method; the CQL file is processed with a\n",
    "# different base seed than the other two.\n",
    "APPENDIX_RUNS = [\n",
    "    (\"CQL\", \"cql.csv\", 0xCAFE0000),\n",
    "    (\"DualDICE + CQL\", \"dualdicecql.csv\", 0xCAFECAF0),\n",
    "    (\"POP-CQL\", \"popcql.csv\", 0xCAFECAF0),\n",
    "]\n",
    "\n",
    "# The loop deliberately does not rebind `cql`: that name holds the summary\n",
    "# DataFrame used by the results table, and overwriting it with a LaTeX string\n",
    "# breaks re-running that cell.\n",
    "for title, csv_file, base_seed in APPENDIX_RUNS:\n",
    "    # Escaped backslash and f-string braces emit exactly \\paragraph{<title>}.\n",
    "    print(f\"\\\\paragraph{{{title}}}\")\n",
    "    print(processAppendix(csv_file, base_seed=base_seed))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrr}\n",
      "\\toprule\n",
      "{} &  random &     sac \\\\\n",
      "env                   &         &         \\\\\n",
      "\\midrule\n",
      "ant-expert-v2         &  -325.6 &  3879.7 \\\\\n",
      "ant-medium-v2         &  -325.6 &  3879.7 \\\\\n",
      "ant-random-v2         &  -325.6 &  3879.7 \\\\\n",
      "halfcheetah-expert-v2 &  -280.2 & 12135.0 \\\\\n",
      "halfcheetah-medium-v2 &  -280.2 & 12135.0 \\\\\n",
      "halfcheetah-random-v2 &  -280.2 & 12135.0 \\\\\n",
      "hopper-expert-v2      &   -20.3 &  3234.3 \\\\\n",
      "hopper-medium-v2      &   -20.3 &  3234.3 \\\\\n",
      "hopper-random-v2      &   -20.3 &  3234.3 \\\\\n",
      "walker2d-expert-v2    &     1.6 &  4592.3 \\\\\n",
      "walker2d-medium-v2    &     1.6 &  4592.3 \\\\\n",
      "walker2d-random-v2    &     1.6 &  4592.3 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1857648/2382787170.py:6: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n",
      "  print(pd.concat([random, sac], axis=1).to_latex(float_format=\"%.1f\"))\n"
     ]
    }
   ],
   "source": [
    "# Reference scores used for normalization, one row per environment in `table`.\n",
    "envs = pd.Series(table.index)\n",
    "reference_columns = [\n",
    "    pd.Series([scores[env] for env in envs], index=envs, name=label)\n",
    "    for label, scores in ((\"random\", infos.REF_MIN_SCORE), (\"sac\", infos.REF_MAX_SCORE))\n",
    "]\n",
    "print(pd.concat(reference_columns, axis=1).to_latex(float_format=\"%.1f\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jaxcql",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "8b6a44bf000437e08e8b98742e0fcfe22ebd81f25901216ee21bd9c9f5aa71f6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
